From e642ca0b9f4326c341f9934980b7a5438b067724 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Victor=20L=C3=B6fgren?= <viclo211@student.liu.se>
Date: Thu, 29 Apr 2021 12:42:07 +0000
Subject: [PATCH] Resolve "Whitelist jwt"

---
 server/app/__init__.py                   | 12 -------
 server/app/apis/__init__.py              |  2 --
 server/app/apis/answers.py               |  1 -
 server/app/apis/auth.py                  | 43 ++++++++++++++----------
 server/app/apis/codes.py                 |  1 -
 server/app/apis/competitions.py          |  2 --
 server/app/core/__init__.py              |  8 ++++-
 server/app/core/codes.py                 |  1 -
 server/app/core/files.py                 |  3 +-
 server/app/core/parsers.py               |  2 +-
 server/app/core/schemas.py               |  2 --
 server/app/core/sockets.py               |  3 +-
 server/app/database/__init__.py          |  6 +---
 server/app/database/controller/add.py    | 13 ++++---
 server/app/database/controller/delete.py | 16 +++++++--
 server/app/database/controller/get.py    |  8 +++--
 server/app/database/controller/search.py |  2 +-
 server/app/database/controller/utils.py  |  1 -
 server/app/database/models.py            | 13 ++++++-
 server/configmodule.py                   |  2 +-
 server/tests/test_app.py                 |  2 +-
 21 files changed, 79 insertions(+), 64 deletions(-)

diff --git a/server/app/__init__.py b/server/app/__init__.py
index 4a28018c..5bdffe82 100644
--- a/server/app/__init__.py
+++ b/server/app/__init__.py
@@ -45,15 +45,3 @@ def create_app(config_name="configmodule.DevelopmentConfig"):
             return response
 
     return app, sio
-
-
-def identity(payload):
-    user_id = payload["identity"]
-    return models.User.query.filter_by(id=user_id)
-
-
-@jwt.token_in_blacklist_loader
-def check_if_token_in_blacklist(decrypted_token):
-    jti = decrypted_token["jti"]
-
-    return models.Blacklist.query.filter_by(jti=jti).first() is not None
diff --git a/server/app/apis/__init__.py b/server/app/apis/__init__.py
index e5ec3d2c..3cb92488 100644
--- a/server/app/apis/__init__.py
+++ b/server/app/apis/__init__.py
@@ -1,5 +1,3 @@
-from functools import wraps
-
 import app.core.http_codes as http_codes
 from flask_jwt_extended import verify_jwt_in_request
 from flask_jwt_extended.utils import get_jwt_claims
diff --git a/server/app/apis/answers.py b/server/app/apis/answers.py
index 3990e4e1..6d0490a9 100644
--- a/server/app/apis/answers.py
+++ b/server/app/apis/answers.py
@@ -1,4 +1,3 @@
-import app.core.http_codes as codes
 import app.database.controller as dbc
 from app.apis import item_response, list_response, protect_route
 from app.core.dto import QuestionAnswerDTO
diff --git a/server/app/apis/auth.py b/server/app/apis/auth.py
index a3be2091..bf9eeefd 100644
--- a/server/app/apis/auth.py
+++ b/server/app/apis/auth.py
@@ -5,10 +5,10 @@ import app.database.controller as dbc
 from app.apis import item_response, protect_route, text_response
 from app.core import sockets
 from app.core.codes import verify_code
-from app.core.dto import AuthDTO, CodeDTO
-from flask_jwt_extended import (create_access_token, create_refresh_token,
-                                get_jwt_identity, get_raw_jwt,
-                                jwt_refresh_token_required)
+from app.core.dto import AuthDTO
+from app.database.models import Whitelist
+from flask_jwt_extended import create_access_token, get_jti, get_raw_jwt
+from flask_jwt_extended.utils import get_jti
 from flask_restx import Resource, inputs, reqparse
 
 api = AuthDTO.api
@@ -32,7 +32,12 @@ def get_user_claims(item_user):
 
 
 def get_code_claims(item_code):
-    return {"view": item_code.view_type.name, "competition_id": item_code.competition_id, "team_id": item_code.team_id, "code": item_code.code}
+    return {
+        "view": item_code.view_type.name,
+        "competition_id": item_code.competition_id,
+        "team_id": item_code.team_id,
+        "code": item_code.code,
+    }
 
 
 @api.route("/test")
@@ -56,18 +61,16 @@ class AuthSignup(Resource):
         return item_response(schema.dump(item_user))
 
 
-@api.route("/delete/<ID>")
-@api.param("ID")
+@api.route("/delete/<user_id>")
+@api.param("user_id")
 class AuthDelete(Resource):
     @protect_route(allowed_roles=["Admin"])
-    def delete(self, ID):
-        item_user = dbc.get.user(ID)
-
+    def delete(self, user_id):
+        item_user = dbc.get.user(user_id)
+        dbc.delete.whitelist_to_blacklist(Whitelist.user_id == user_id)
         dbc.delete.default(item_user)
-        if int(ID) == get_jwt_identity():
-            jti = get_raw_jwt()["jti"]
-            dbc.add.blacklist(jti)
-        return text_response(f"User {ID} deleted")
+
+        return text_response(f"User {user_id} deleted")
 
 
 @api.route("/login")
@@ -82,9 +85,10 @@ class AuthLogin(Resource):
             api.abort(codes.UNAUTHORIZED, "Invalid email or password")
 
         access_token = create_access_token(item_user.id, user_claims=get_user_claims(item_user))
-        refresh_token = create_refresh_token(item_user.id)
+        # refresh_token = create_refresh_token(item_user.id)
 
-        response = {"id": item_user.id, "access_token": access_token, "refresh_token": refresh_token}
+        response = {"id": item_user.id, "access_token": access_token}
+        dbc.add.whitelist(get_jti(access_token), item_user.id)
         return response
 
 
@@ -98,7 +102,7 @@ class AuthLoginCode(Resource):
             api.abort(codes.UNAUTHORIZED, "Invalid code")
 
         item_code = dbc.get.code_by_code(code)
-    
+
         if item_code.view_type_id != 4:
             if item_code.competition_id not in sockets.presentations:
                 api.abort(codes.UNAUTHORIZED, "Competition not active")
@@ -107,6 +111,7 @@ class AuthLoginCode(Resource):
             item_code.id, user_claims=get_code_claims(item_code), expires_delta=timedelta(hours=8)
         )
 
+        dbc.add.whitelist(get_jti(access_token), competition_id=item_code.competition_id)
         response = {
             "competition_id": item_code.competition_id,
             "view": item_code.view_type.name,
@@ -122,9 +127,12 @@ class AuthLogout(Resource):
     def post(self):
         jti = get_raw_jwt()["jti"]
         dbc.add.blacklist(jti)
+        Whitelist.query.filter(Whitelist.jti == jti).delete()
+        dbc.utils.commit()
         return text_response("Logout")
 
 
+"""
 @api.route("/refresh")
 class AuthRefresh(Resource):
     @protect_route(allowed_roles=["*"])
@@ -137,3 +145,4 @@ class AuthRefresh(Resource):
         dbc.add.blacklist(old_jti)
         response = {"access_token": access_token}
         return response
+"""
diff --git a/server/app/apis/codes.py b/server/app/apis/codes.py
index d07e1743..12aeb088 100644
--- a/server/app/apis/codes.py
+++ b/server/app/apis/codes.py
@@ -1,6 +1,5 @@
 import app.database.controller as dbc
 from app.apis import item_response, list_response, protect_route
-from app.core import http_codes as codes
 from app.core.dto import CodeDTO
 from app.database.models import Code
 from flask_restx import Resource
diff --git a/server/app/apis/competitions.py b/server/app/apis/competitions.py
index c5ff37c9..bfd06fac 100644
--- a/server/app/apis/competitions.py
+++ b/server/app/apis/competitions.py
@@ -1,5 +1,3 @@
-import time
-
 import app.database.controller as dbc
 from app.apis import item_response, list_response, protect_route
 from app.core.dto import CompetitionDTO
diff --git a/server/app/core/__init__.py b/server/app/core/__init__.py
index 94d1daf2..c80bb5c0 100644
--- a/server/app/core/__init__.py
+++ b/server/app/core/__init__.py
@@ -2,7 +2,7 @@
 The core submodule contains everything important to the server that doesn't
 fit neatly in either apis or database.
 """
-
+import app.database.models as models
 from app.database import Base, ExtendedQuery
 from flask_bcrypt import Bcrypt
 from flask_jwt_extended.jwt_manager import JWTManager
@@ -13,3 +13,9 @@ db = SQLAlchemy(model_class=Base, query_class=ExtendedQuery)
 bcrypt = Bcrypt()
 jwt = JWTManager()
 ma = Marshmallow()
+
+
+@jwt.token_in_blacklist_loader
+def check_if_token_in_blacklist(decrypted_token):
+    jti = decrypted_token["jti"]
+    return models.Blacklist.query.filter_by(jti=jti).first() is not None
diff --git a/server/app/core/codes.py b/server/app/core/codes.py
index ad6d844c..c52ddf8d 100644
--- a/server/app/core/codes.py
+++ b/server/app/core/codes.py
@@ -4,7 +4,6 @@ Contains all functions purely related to creating and verifying a code.
 
 import random
 import re
-import string
 
 CODE_LENGTH = 6
 ALLOWED_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
diff --git a/server/app/core/files.py b/server/app/core/files.py
index 5d22c8e0..f45b9bb5 100644
--- a/server/app/core/files.py
+++ b/server/app/core/files.py
@@ -2,10 +2,9 @@
 Contains functions related to file handling, mainly saving and deleting images.
 """
 
-from PIL import Image, ImageChops
+from PIL import Image
 from flask import current_app, has_app_context
 import os
-import datetime
 from flask_uploads import IMAGES, UploadSet
 
 if has_app_context():
diff --git a/server/app/core/parsers.py b/server/app/core/parsers.py
index 4c152520..3b541cf2 100644
--- a/server/app/core/parsers.py
+++ b/server/app/core/parsers.py
@@ -2,7 +2,7 @@
 This module contains the parsers used to parse the data gotten in api requests.
 """
 
-from flask_restx import inputs, reqparse
+from flask_restx import reqparse
 
 
 class Sentinel:
diff --git a/server/app/core/schemas.py b/server/app/core/schemas.py
index a2e81c25..d9331ae7 100644
--- a/server/app/core/schemas.py
+++ b/server/app/core/schemas.py
@@ -3,8 +3,6 @@ This module contains schemas used to convert database objects into
 dictionaries.
 """
 
-from marshmallow.decorators import pre_load
-from marshmallow.decorators import pre_dump, post_dump
 import app.database.models as models
 from app.core import ma
 from marshmallow_sqlalchemy import fields
diff --git a/server/app/core/sockets.py b/server/app/core/sockets.py
index b3bb9f87..4422d466 100644
--- a/server/app/core/sockets.py
+++ b/server/app/core/sockets.py
@@ -6,9 +6,8 @@ connected to the same presentation.
 import logging
 from typing import Dict
 
-import app.database.controller as dbc
 from app.core import db
-from app.database.models import Code, Competition, Slide, Team, ViewType
+from app.database.models import Code, Slide, ViewType
 from flask.globals import request
 from flask_jwt_extended import verify_jwt_in_request
 from flask_jwt_extended.utils import get_jwt_claims
diff --git a/server/app/database/__init__.py b/server/app/database/__init__.py
index 9b002830..2840c94e 100644
--- a/server/app/database/__init__.py
+++ b/server/app/database/__init__.py
@@ -3,15 +3,11 @@ The database submodule contaisn all functionality that has to do with the
 database. It can add, get, delete, edit, search and copy items.
 """
 
-import json
-
 from flask_restx import abort
 from flask_sqlalchemy import BaseQuery
 from flask_sqlalchemy.model import Model
-from sqlalchemy import Column, DateTime, Text
+from sqlalchemy import Column, DateTime
 from sqlalchemy.sql import func
-from sqlalchemy.types import TypeDecorator
-from sqlalchemy import event
 
 
 class Base(Model):
diff --git a/server/app/database/controller/add.py b/server/app/database/controller/add.py
index 18bcd4b1..5f9aec28 100644
--- a/server/app/database/controller/add.py
+++ b/server/app/database/controller/add.py
@@ -6,13 +6,12 @@ import os
 
 import app.core.http_codes as codes
 from app.core import db
-from app.database.controller import get, search, utils
+from app.database.controller import get, utils
 from app.database.models import (
     Blacklist,
     City,
     Code,
     Competition,
-    Component,
     ComponentType,
     ImageComponent,
     Media,
@@ -28,14 +27,12 @@ from app.database.models import (
     TextComponent,
     User,
     ViewType,
+    Whitelist,
 )
 from flask.globals import current_app
 from flask_restx import abort
 from PIL import Image
 from sqlalchemy import exc
-from sqlalchemy.orm import with_polymorphic
-from sqlalchemy.orm import relation
-from sqlalchemy.orm.session import sessionmaker
 from flask import current_app
 
 from app.database.types import ID_IMAGE_COMPONENT, ID_QUESTION_COMPONENT, ID_TEXT_COMPONENT
@@ -197,6 +194,12 @@ def blacklist(jti):
     return db_add(Blacklist(jti))
 
 
+def whitelist(jti, user_id=None, competition_id=None):
+    """ Adds a whitelist to the database. """
+
+    return db_add(Whitelist(jti, user_id, competition_id))
+
+
 def mediaType(name):
     """ Adds a media type to the database. """
 
diff --git a/server/app/database/controller/delete.py b/server/app/database/controller/delete.py
index 93a4c339..b0b36fba 100644
--- a/server/app/database/controller/delete.py
+++ b/server/app/database/controller/delete.py
@@ -5,9 +5,8 @@ This file contains functionality to delete data to the database.
 import app.core.http_codes as codes
 import app.database.controller as dbc
 from app.core import db
-from app.database.models import Blacklist, City, Competition, Role, Slide, User
+from app.database.models import Whitelist
 from flask_restx import abort
-from sqlalchemy import exc
 
 
 def default(item):
@@ -20,6 +19,19 @@ def default(item):
         abort(codes.INTERNAL_SERVER_ERROR, f"Item of type {type(item)} could not be deleted")
 
 
+def whitelist_to_blacklist(filters):
+    """
+    Remove whitelist by condition(filters) and insert those into blacklist
+    Example: When delete user all whitelisted tokens for that user should be blacklisted
+    """
+    whitelist = Whitelist.query.filter(filters).all()
+    for item in whitelist:
+        dbc.add.blacklist(item.jti)
+
+    Whitelist.query.filter(filters).delete()
+    dbc.utils.commit()
+
+
 def component(item_component):
     """ Deletes component. """
 
diff --git a/server/app/database/controller/get.py b/server/app/database/controller/get.py
index ba975d35..b6701ca5 100644
--- a/server/app/database/controller/get.py
+++ b/server/app/database/controller/get.py
@@ -27,17 +27,19 @@ def all(db_type):
     return db_type.query.all()
 
 
-def one(db_type, id):
+def one(db_type, id, required=True):
     """ Get lazy db-item in the table that has the same id. """
 
-    return db_type.query.filter(db_type.id == id).first_extended()
+    return db_type.query.filter(db_type.id == id).first_extended(required=required)
 
 
 ### Codes ###
 def code_by_code(code):
     """ Gets the code object associated with the provided code. """
 
-    return Code.query.filter(Code.code == code.upper()).first_extended( True, "A presentation with that code does not exist")
+    return Code.query.filter(Code.code == code.upper()).first_extended(
+        True, "A presentation with that code does not exist"
+    )
 
 
 def code_list(competition_id):
diff --git a/server/app/database/controller/search.py b/server/app/database/controller/search.py
index bfc40843..4d112a5f 100644
--- a/server/app/database/controller/search.py
+++ b/server/app/database/controller/search.py
@@ -2,7 +2,7 @@
 This file contains functionality to find data to the database.
 """
 
-from app.database.models import Competition, Media, Question, Slide, Team, User
+from app.database.models import Competition, Media, Question, Slide, User
 
 
 def image(filename, page=0, page_size=15, order=1, order_by=None):
diff --git a/server/app/database/controller/utils.py b/server/app/database/controller/utils.py
index c01205a6..14eaa48d 100644
--- a/server/app/database/controller/utils.py
+++ b/server/app/database/controller/utils.py
@@ -7,7 +7,6 @@ from app.core import db
 from app.core.codes import generate_code_string
 from app.database.models import Code
 from flask_restx import abort
-from sqlalchemy import exc
 
 
 def move_slides(item_competition, start_order, end_order):
diff --git a/server/app/database/models.py b/server/app/database/models.py
index 0f909aee..a59335d9 100644
--- a/server/app/database/models.py
+++ b/server/app/database/models.py
@@ -12,10 +12,21 @@ from app.database.types import ID_IMAGE_COMPONENT, ID_QUESTION_COMPONENT, ID_TEX
 STRING_SIZE = 254
 
 
+class Whitelist(db.Model):
+    id = db.Column(db.Integer, primary_key=True)
+    jti = db.Column(db.String, unique=True)
+    user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=True)
+    competition_id = db.Column(db.Integer, db.ForeignKey("competition.id"), nullable=True)
+
+    def __init__(self, jti, user_id=None, competition_id=None):
+        self.jti = jti
+        self.user_id = user_id
+        self.competition_id = competition_id
+
+
 class Blacklist(db.Model):
     id = db.Column(db.Integer, primary_key=True)
     jti = db.Column(db.String, unique=True)
-    expire_date = db.Column(db.Integer, nullable=True)
 
     def __init__(self, jti):
         self.jti = jti
diff --git a/server/configmodule.py b/server/configmodule.py
index 8c07211e..93d21cbe 100644
--- a/server/configmodule.py
+++ b/server/configmodule.py
@@ -12,7 +12,7 @@ class Config:
     JWT_BLACKLIST_ENABLED = True
     JWT_BLACKLIST_TOKEN_CHECKS = ["access", "refresh"]
     JWT_ACCESS_TOKEN_EXPIRES = timedelta(days=2)
-    JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
+    # JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
     UPLOADED_PHOTOS_DEST = os.path.join(os.getcwd(), "app", "static", "images")
     THUMBNAIL_SIZE = (120, 120)
     SECRET_KEY = os.urandom(24)
diff --git a/server/tests/test_app.py b/server/tests/test_app.py
index d59428a6..3deb6f29 100644
--- a/server/tests/test_app.py
+++ b/server/tests/test_app.py
@@ -149,7 +149,7 @@ def test_auth_and_user_api(client):
     # Try loggin with right PASSWORD
     response, body = post(client, "/api/auth/login", {"email": "test1@test.se", "password": "abc123"})
     assert response.status_code == codes.OK
-    refresh_token = body["refresh_token"]
+    # refresh_token = body["refresh_token"]
     headers = {"Authorization": "Bearer " + body["access_token"]}
 
     # Get the current user
-- 
GitLab