diff --git a/server/app/core/sockets.py b/server/app/core/sockets.py index 9115a0f5866c10e194a98dfbcce7645fb3fe90e0..b7a7e53c24e2a1cff90392cb5f70cb69f028eb69 100644 --- a/server/app/core/sockets.py +++ b/server/app/core/sockets.py @@ -3,8 +3,8 @@ Contains all functionality related sockets. That is starting, joining, ending, disconnecting from and syncing active competitions. """ import logging +from functools import wraps -from decorator import decorator from flask.globals import request from flask_jwt_extended import verify_jwt_in_request from flask_jwt_extended.utils import get_jwt @@ -55,8 +55,7 @@ def _get_sync_variables(active_competition, sync_values): return {key: value for key, value in active_competition.items() if key in sync_values} -@decorator -def authorize_client(f, allowed_views=None, require_active_competition=True, *args, **kwargs): +def authorization(allowed_views=None, require_active_competition=True): """ Decorator used to authorize a client that sends socket events. Check that the client has authorization headers, that client view gotten from claims @@ -64,31 +63,41 @@ def authorize_client(f, allowed_views=None, require_active_competition=True, *ar if require_active_competition is True. """ - try: - verify_jwt_in_request() - except: - logger.error(f"Won't call function '{f.__name__}': Missing Authorization Header") - return + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + verify_jwt_in_request() + except: + logger.error(f"Won't call function '{func.__name__}': Missing Authorization Header") + return - def _is_allowed(allowed, actual): - return actual and "*" in allowed or actual in allowed + def _is_allowed(allowed, actual): + return actual and "*" in allowed or actual in allowed - competition_id, view = _unpack_claims() + competition_id, view = _unpack_claims() + + if require_active_competition and not is_active_competition(competition_id): + logger.error(f"Won't call function '{func.__name__}': Competition '{competition_id}' is not active") + return + + nonlocal allowed_views + allowed_views = allowed_views or [] + if not _is_allowed(allowed_views, view): + logger.error( + f"Won't call function '{func.__name__}': View '{view}' is not '{' or '.join(allowed_views)}'" + ) + return - if require_active_competition and not is_active_competition(competition_id): - logger.error(f"Won't call function '{f.__name__}': Competition '{competition_id}' is not active") - return + return func(*args, **kwargs) - allowed_views = allowed_views or [] - if not _is_allowed(allowed_views, view): - logger.error(f"Won't call function '{f.__name__}': View '{view}' is not '{' or '.join(allowed_views)}'") - return + return wrapper - return f(*args, **kwargs) + return decorator @sio.event -@authorize_client(require_active_competition=False, allowed_views=["*"]) +@authorization(require_active_competition=False, allowed_views=["*"]) def connect() -> None: """ Connect to a active competition. If competition with competition_id is not active, @@ -122,7 +131,7 @@ def connect() -> None: @sio.event -@authorize_client(allowed_views=["*"]) +@authorization(allowed_views=["*"]) def disconnect() -> None: """ Remove client from the active_competition it was in. Delete active_competition if no @@ -139,7 +148,7 @@ def disconnect() -> None: @sio.event -@authorize_client(allowed_views=["Operator"]) +@authorization(allowed_views=["Operator"]) def end_presentation() -> None: """ End a presentation by sending end_presentation to all connected clients. @@ -150,7 +159,7 @@ def end_presentation() -> None: @sio.event -@authorize_client(allowed_views=["Operator"]) +@authorization(allowed_views=["Operator"]) def sync(data) -> None: """ Update all values from data thats in an active_competitions. Also sync all diff --git a/server/requirements.txt b/server/requirements.txt index b32db8af5507f9b307ec6f6ccfb03d1608b6ae75..bcb325f5f3a16a5ee0e655435923c38d8926b807 100644 Binary files a/server/requirements.txt and b/server/requirements.txt differ