From 0ae34deb446c5383435ee268569b909acdd4d546 Mon Sep 17 00:00:00 2001 From: robban64 <carl@schonfelder.se> Date: Wed, 28 Apr 2021 09:53:58 +0200 Subject: [PATCH] add: protected route for sockets --- server/app/core/sockets.py | 83 +++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/server/app/core/sockets.py b/server/app/core/sockets.py index 00b2ddf5..bd75ae8a 100644 --- a/server/app/core/sockets.py +++ b/server/app/core/sockets.py @@ -4,12 +4,14 @@ from app.database.models import Competition, Slide, Team, ViewType, Code from flask.globals import request from flask_socketio import SocketIO, emit, join_room import logging +from flask_jwt_extended import verify_jwt_in_request +from flask_jwt_extended.utils import get_jwt_claims logger = logging.getLogger(__name__) logger.propagate = False logger.setLevel(logging.INFO) -formatter = logging.Formatter('[%(levelname)s] %(funcName)s: %(message)s') +formatter = logging.Formatter("[%(levelname)s] %(funcName)s: %(message)s") stream_handler = logging.StreamHandler() stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) @@ -19,6 +21,34 @@ sio = SocketIO(cors_allowed_origins="http://localhost:3000") presentations = {} +def _is_allowed(allowed, actual): + return actual and "*" in allowed or actual in allowed + + +def protect_route(allowed_views=None): + def wrapper(func): + def inner(*args, **kwargs): + try: + verify_jwt_in_request() + except: + logger.warning("Missing Authorization Header") + return + + nonlocal allowed_views + allowed_views = allowed_views or [] + claims = get_jwt_claims() + view = claims.get("view") + if not _is_allowed(allowed_views, view): + logger.warning(f"View {view} is not allowed to access route only accessible by {allowed_views}") + return + + return func(*args, **kwargs) + + return inner + + return wrapper + + @sio.on("connect") def connect(): logger.info(f"Client '{request.sid}' connected") @@ -44,7 +74,9 @@ def start_presentation(data): competition_id = data["competition_id"] if competition_id in presentations: - logger.error(f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active") + logger.error( + f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active" + ) return presentations[competition_id] = { @@ -58,16 +90,21 @@ def start_presentation(data): logger.info(f"Client '{request.sid}' started competition '{competition_id}'") + @sio.on("end_presentation") def end_presentation(data): competition_id = data["competition_id"] if competition_id not in presentations: - logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists") + logger.error( + f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists" + ) return if request.sid not in presentations[competition_id]["clients"]: - logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation") + logger.error( + f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation" + ) return if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": @@ -96,11 +133,15 @@ def join_presentation(data): competition_id = item_code.competition_id if competition_id not in presentations: - logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists") + logger.error( + f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists" + ) return if request.sid in presentations[competition_id]["clients"]: - logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation") + logger.error( + f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation" + ) return # TODO: Write function in database controller to do this @@ -115,26 +156,35 @@ def join_presentation(data): @sio.on("set_slide") +@protect_route(allowed_views=["Operator"]) def set_slide(data): competition_id = data["competition_id"] slide_order = data["slide_order"] if competition_id not in presentations: - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists" + ) return if request.sid not in presentations[competition_id]["clients"]: - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation" + ) return if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator" + ) return num_slides = db.session.query(Slide).filter(Slide.competition_id == competition_id).count() if not (0 <= slide_order < num_slides): - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist" + ) return presentations[competition_id]["slide"] = slide_order @@ -151,15 +201,21 @@ def set_timer(data): timer = data["timer"] if competition_id not in presentations: - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists" + ) return if request.sid not in presentations[competition_id]["clients"]: - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation" + ) return if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": - logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator") + logger.error( + f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator" + ) return # TODO: Save timer in presentation, maybe? @@ -168,4 +224,3 @@ def set_timer(data): logger.debug(f"Emitting event 'set_timer' to room {competition_id} including self") logger.info(f"Client '{request.sid}' set timer '{timer}' in presentation '{competition_id}'") - -- GitLab