From 0ae34deb446c5383435ee268569b909acdd4d546 Mon Sep 17 00:00:00 2001
From: robban64 <carl@schonfelder.se>
Date: Wed, 28 Apr 2021 09:53:58 +0200
Subject: [PATCH] add: protected route for sockets

---
 server/app/core/sockets.py | 83 +++++++++++++++++++++++++++++++-------
 1 file changed, 69 insertions(+), 14 deletions(-)

diff --git a/server/app/core/sockets.py b/server/app/core/sockets.py
index 00b2ddf5..bd75ae8a 100644
--- a/server/app/core/sockets.py
+++ b/server/app/core/sockets.py
@@ -4,12 +4,14 @@ from app.database.models import Competition, Slide, Team, ViewType, Code
 from flask.globals import request
 from flask_socketio import SocketIO, emit, join_room
 import logging
+from flask_jwt_extended import verify_jwt_in_request
+from flask_jwt_extended.utils import get_jwt_claims
 
 logger = logging.getLogger(__name__)
 logger.propagate = False
 logger.setLevel(logging.INFO)
 
-formatter = logging.Formatter('[%(levelname)s] %(funcName)s: %(message)s')
+formatter = logging.Formatter("[%(levelname)s] %(funcName)s: %(message)s")
 stream_handler = logging.StreamHandler()
 stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)
@@ -19,6 +21,34 @@ sio = SocketIO(cors_allowed_origins="http://localhost:3000")
 presentations = {}
 
 
+def _is_allowed(allowed, actual):
+    return actual and "*" in allowed or actual in allowed
+
+
+def protect_route(allowed_views=None):
+    def wrapper(func):
+        def inner(*args, **kwargs):
+            try:
+                verify_jwt_in_request()
+            except:
+                logger.warning("Missing Authorization Header")
+                return
+
+            nonlocal allowed_views
+            allowed_views = allowed_views or []
+            claims = get_jwt_claims()
+            view = claims.get("view")
+            if not _is_allowed(allowed_views, view):
+                logger.warning(f"View {view} is not allowed to access route only accessible by {allowed_views}")
+                return
+
+            return func(*args, **kwargs)
+
+        return inner
+
+    return wrapper
+
+
 @sio.on("connect")
 def connect():
     logger.info(f"Client '{request.sid}' connected")
@@ -44,7 +74,9 @@ def start_presentation(data):
     competition_id = data["competition_id"]
 
     if competition_id in presentations:
-        logger.error(f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active")
+        logger.error(
+            f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active"
+        )
         return
 
     presentations[competition_id] = {
@@ -58,16 +90,21 @@ def start_presentation(data):
 
     logger.info(f"Client '{request.sid}' started competition '{competition_id}'")
 
+
 @sio.on("end_presentation")
 def end_presentation(data):
     competition_id = data["competition_id"]
 
     if competition_id not in presentations:
-        logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists")
+        logger.error(
+            f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists"
+        )
         return
 
     if request.sid not in presentations[competition_id]["clients"]:
-        logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation")
+        logger.error(
+            f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation"
+        )
         return
 
     if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
@@ -96,11 +133,15 @@ def join_presentation(data):
     competition_id = item_code.competition_id
 
     if competition_id not in presentations:
-        logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists")
+        logger.error(
+            f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists"
+        )
         return
 
     if request.sid in presentations[competition_id]["clients"]:
-        logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation")
+        logger.error(
+            f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation"
+        )
         return
 
     # TODO: Write function in database controller to do this
@@ -115,26 +156,35 @@ def join_presentation(data):
 
 
 @sio.on("set_slide")
+@protect_route(allowed_views=["Operator"])
 def set_slide(data):
     competition_id = data["competition_id"]
     slide_order = data["slide_order"]
 
     if competition_id not in presentations:
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists"
+        )
         return
 
     if request.sid not in presentations[competition_id]["clients"]:
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation"
+        )
         return
 
     if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator"
+        )
         return
 
     num_slides = db.session.query(Slide).filter(Slide.competition_id == competition_id).count()
 
     if not (0 <= slide_order < num_slides):
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist"
+        )
         return
 
     presentations[competition_id]["slide"] = slide_order
@@ -151,15 +201,21 @@ def set_timer(data):
     timer = data["timer"]
 
     if competition_id not in presentations:
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists"
+        )
         return
 
     if request.sid not in presentations[competition_id]["clients"]:
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation"
+        )
         return
 
     if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
-        logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator")
+        logger.error(
+            f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator"
+        )
         return
 
     # TODO: Save timer in presentation, maybe?
@@ -168,4 +224,3 @@ def set_timer(data):
     logger.debug(f"Emitting event 'set_timer' to room {competition_id} including self")
 
     logger.info(f"Client '{request.sid}' set timer '{timer}' in presentation '{competition_id}'")
-
-- 
GitLab