diff --git a/server/app/apis/__init__.py b/server/app/apis/__init__.py index 5eab3f829ae81e17ecc269d98f568eeacc45a68c..7d6d5c41be8d99fa027548ef4afc9da90605e475 100644 --- a/server/app/apis/__init__.py +++ b/server/app/apis/__init__.py @@ -21,14 +21,19 @@ def check_jwt(editor=False, *views): claims = get_jwt_claims() role = claims.get("role") view = claims.get("view") + competition_id = claims.get("competition_id") + competition_id_args = kwargs.get("competition_id") + if role == "Admin": return fn(*args, **kwargs) elif editor and role == "Editor": return fn(*args, **kwargs) - elif view in views: - return fn(*args, **kwargs) - else: - abort(http_codes.UNAUTHORIZED) + + if competition_id_args and view in views: + if competition_id == competition_id_args: + return fn(*args, **kwargs) + + abort(http_codes.UNAUTHORIZED) return decorator diff --git a/server/app/apis/auth.py b/server/app/apis/auth.py index 37da9e98d4b5970cd82b47dd54ad940cc53b6736..b1ec5388e015f212925a9994420168f1c4334f5e 100644 --- a/server/app/apis/auth.py +++ b/server/app/apis/auth.py @@ -33,6 +33,10 @@ def get_user_claims(item_user): return {"role": item_user.role.name, "city_id": item_user.city_id} +def get_code_claims(item_code): + return {"view": item_code.view_type.name, "competition_id": item_code.competition_id} + + @api.route("/signup") class AuthSignup(Resource): @check_jwt(editor=False) @@ -89,7 +93,16 @@ class AuthLoginCode(Resource): api.abort(codes.BAD_REQUEST, "Invalid code") item_code = dbc.get.code_by_code(code) - return item_response(CodeDTO.schema.dump(item_code)) + + access_token = create_access_token(item_code.id, user_claims=get_code_claims(item_code)) + + response = { + "competition_id": item_code.competition_id, + "view_type_id": item_code.view_type_id, + "team_id": item_code.team_id, + "access_token": access_token, + } + return response @api.route("/logout") diff --git a/server/app/database/models.py b/server/app/database/models.py index 3012938955f5c9e4fe1e285a0d95478ca7f7b8dd..b697b830cf2a8c4d55d3fb69a5be56ceececa50e 100644 --- a/server/app/database/models.py +++ b/server/app/database/models.py @@ -230,6 +230,8 @@ class Code(db.Model): competition_id = db.Column(db.Integer, db.ForeignKey("competition.id"), nullable=False) team_id = db.Column(db.Integer, db.ForeignKey("team.id"), nullable=True) + view_type = db.relationship("ViewType", uselist=False) + def __init__(self, code, view_type_id, competition_id=None, team_id=None): self.code = code self.view_type_id = view_type_id @@ -240,7 +242,6 @@ class Code(db.Model): class ViewType(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(STRING_SIZE), unique=True) - codes = db.relationship("Code", backref="view_type") def __init__(self, name): self.name = name