Skip to content
Snippets Groups Projects
Commit 0ae34deb authored by robban64's avatar robban64
Browse files

add: protected route for sockets

parent 5404d2dc
No related branches found
No related tags found
1 merge request!115Resolve "Add authorization to socketio"
...@@ -4,12 +4,14 @@ from app.database.models import Competition, Slide, Team, ViewType, Code ...@@ -4,12 +4,14 @@ from app.database.models import Competition, Slide, Team, ViewType, Code
from flask.globals import request from flask.globals import request
from flask_socketio import SocketIO, emit, join_room from flask_socketio import SocketIO, emit, join_room
import logging import logging
from flask_jwt_extended import verify_jwt_in_request
from flask_jwt_extended.utils import get_jwt_claims
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.propagate = False logger.propagate = False
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
formatter = logging.Formatter('[%(levelname)s] %(funcName)s: %(message)s') formatter = logging.Formatter("[%(levelname)s] %(funcName)s: %(message)s")
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
...@@ -19,6 +21,34 @@ sio = SocketIO(cors_allowed_origins="http://localhost:3000") ...@@ -19,6 +21,34 @@ sio = SocketIO(cors_allowed_origins="http://localhost:3000")
presentations = {} presentations = {}
def _is_allowed(allowed, actual):
return actual and "*" in allowed or actual in allowed
def protect_route(allowed_views=None):
def wrapper(func):
def inner(*args, **kwargs):
try:
verify_jwt_in_request()
except:
logger.warning("Missing Authorization Header")
return
nonlocal allowed_views
allowed_views = allowed_views or []
claims = get_jwt_claims()
view = claims.get("view")
if not _is_allowed(allowed_views, view):
logger.warning(f"View {view} is not allowed to access route only accessible by {allowed_views}")
return
return func(*args, **kwargs)
return inner
return wrapper
@sio.on("connect") @sio.on("connect")
def connect(): def connect():
logger.info(f"Client '{request.sid}' connected") logger.info(f"Client '{request.sid}' connected")
...@@ -44,7 +74,9 @@ def start_presentation(data): ...@@ -44,7 +74,9 @@ def start_presentation(data):
competition_id = data["competition_id"] competition_id = data["competition_id"]
if competition_id in presentations: if competition_id in presentations:
logger.error(f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active") logger.error(
f"Client '{request.sid}' failed to start competition '{competition_id}', presentation already active"
)
return return
presentations[competition_id] = { presentations[competition_id] = {
...@@ -58,16 +90,21 @@ def start_presentation(data): ...@@ -58,16 +90,21 @@ def start_presentation(data):
logger.info(f"Client '{request.sid}' started competition '{competition_id}'") logger.info(f"Client '{request.sid}' started competition '{competition_id}'")
@sio.on("end_presentation") @sio.on("end_presentation")
def end_presentation(data): def end_presentation(data):
competition_id = data["competition_id"] competition_id = data["competition_id"]
if competition_id not in presentations: if competition_id not in presentations:
logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists") logger.error(
f"Client '{request.sid}' failed to end presentation '{competition_id}', no such presentation exists"
)
return return
if request.sid not in presentations[competition_id]["clients"]: if request.sid not in presentations[competition_id]["clients"]:
logger.error(f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation") logger.error(
f"Client '{request.sid}' failed to end presentation '{competition_id}', client not in presentation"
)
return return
if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
...@@ -96,11 +133,15 @@ def join_presentation(data): ...@@ -96,11 +133,15 @@ def join_presentation(data):
competition_id = item_code.competition_id competition_id = item_code.competition_id
if competition_id not in presentations: if competition_id not in presentations:
logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists") logger.error(
f"Client '{request.sid}' failed to join presentation '{competition_id}', no such presentation exists"
)
return return
if request.sid in presentations[competition_id]["clients"]: if request.sid in presentations[competition_id]["clients"]:
logger.error(f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation") logger.error(
f"Client '{request.sid}' failed to join presentation '{competition_id}', client already in presentation"
)
return return
# TODO: Write function in database controller to do this # TODO: Write function in database controller to do this
...@@ -115,26 +156,35 @@ def join_presentation(data): ...@@ -115,26 +156,35 @@ def join_presentation(data):
@sio.on("set_slide") @sio.on("set_slide")
@protect_route(allowed_views=["Operator"])
def set_slide(data): def set_slide(data):
competition_id = data["competition_id"] competition_id = data["competition_id"]
slide_order = data["slide_order"] slide_order = data["slide_order"]
if competition_id not in presentations: if competition_id not in presentations:
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists"
)
return return
if request.sid not in presentations[competition_id]["clients"]: if request.sid not in presentations[competition_id]["clients"]:
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation"
)
return return
if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator"
)
return return
num_slides = db.session.query(Slide).filter(Slide.competition_id == competition_id).count() num_slides = db.session.query(Slide).filter(Slide.competition_id == competition_id).count()
if not (0 <= slide_order < num_slides): if not (0 <= slide_order < num_slides):
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', slide number {slide_order} does not exist"
)
return return
presentations[competition_id]["slide"] = slide_order presentations[competition_id]["slide"] = slide_order
...@@ -151,15 +201,21 @@ def set_timer(data): ...@@ -151,15 +201,21 @@ def set_timer(data):
timer = data["timer"] timer = data["timer"]
if competition_id not in presentations: if competition_id not in presentations:
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', no such presentation exists"
)
return return
if request.sid not in presentations[competition_id]["clients"]: if request.sid not in presentations[competition_id]["clients"]:
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client not in presentation"
)
return return
if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator": if presentations[competition_id]["clients"][request.sid]["view_type"] != "Operator":
logger.error(f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator") logger.error(
f"Client '{request.sid}' failed to set slide in presentation '{competition_id}', client is not operator"
)
return return
# TODO: Save timer in presentation, maybe? # TODO: Save timer in presentation, maybe?
...@@ -168,4 +224,3 @@ def set_timer(data): ...@@ -168,4 +224,3 @@ def set_timer(data):
logger.debug(f"Emitting event 'set_timer' to room {competition_id} including self") logger.debug(f"Emitting event 'set_timer' to room {competition_id} including self")
logger.info(f"Client '{request.sid}' set timer '{timer}' in presentation '{competition_id}'") logger.info(f"Client '{request.sid}' set timer '{timer}' in presentation '{competition_id}'")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment