jwt_utils.py 4.3 KB
Newer Older
lcn's avatar
lcn committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
import binascii
import json
import os

from jwt.compat import binary_type, text_type
from jwt.exceptions import DecodeError
from jwt.utils import base64url_decode
from sanic import Sanic, response
from sanic_jwt import Initialize
from sanic_jwt.decorators import protected

from pot_libs.aredis_util import aredis_utils
from pot_libs.aredis_util.aredis_utils import RedisUtils
from unify_api.modules.users.procedures.login_pds import ExtAuthentication

# Access-token settings, read by JWTUtils.init_app (passed to sanic-jwt's
# Initialize) and by store_token_blacklist (blacklist TTL).
TOKEN_CONFIG = {
    # Token lifetime in seconds (raised from 12 h to 7 days).
    "expiration_delta": 7 * 24 * 60 * 60,
    # Signing secret. Deployments should set the JWT_SECRET environment
    # variable so the real key does not live in source control; the literal
    # below is only a backward-compatible fallback (originally generated with
    # http://tool.c7sky.com/password/). An empty JWT_SECRET is ignored.
    "secret": os.environ.get("JWT_SECRET") or (
        "NI3g8jVepsoiK0aH7RgKrPSc9MQ4GwnAedDYzO65fQiLWEUk1CBJuthZnTFb2vyq"),
    "refresh_token_enabled": True,
}

# Refresh-token timing: 48 h lifetime and a 3 min leeway (presumably seconds,
# like TOKEN_CONFIG).
# NOTE(review): not referenced anywhere in this file; confirm it is used.
REFRESH_CONFIG = {
    "expiration_delta": 3600 * 48,
    "leeway": 60 * 3,
}


async def check_token_blacklist(token):
    """Report whether *token* has been put on the logout blacklist.

    Reads the Redis key ``auth:block_list:<token>``; any truthy stored value
    means the token is blacklisted.

    :param token: the encoded JWT, used verbatim inside the key.
    :return: ``True`` if blacklisted, otherwise ``False``.
    """
    stored = await aredis_utils.RedisUtils().get("auth:block_list:%s" % token)
    return bool(stored)


# when user logout, store user token
async def store_token_blacklist(token, *args, **kwargs):
    if isinstance(token, text_type):
        token = token.encode('utf-8')
    if not issubclass(type(token), binary_type):
        raise DecodeError("Invalid token type. Token must be a {0}".format(
            binary_type))
    try:
        _, payload_segment, _ = token.rsplit(b'.')
    except ValueError:
        raise DecodeError('Not enough segments')
    try:
        payload = base64url_decode(payload_segment)
    except (TypeError, binascii.Error):
        raise DecodeError('Invalid payload padding')
    # TTL
    exp = TOKEN_CONFIG.get("expiration_delta", 86400)
    payload = json.loads(payload.decode('utf-8'))
    user_id = payload.get("user_id", 1)
    if isinstance(token, bytes):
        token = token.decode('utf-8', 'ignore')
    await aredis_utils.RedisUtils().setex("auth:block_list:%s" % token, exp, user_id)


async def store_refresh_token(user_id, refresh_token, *args, **kwargs):
    """Save *refresh_token* for *user_id* in Redis.

    Registered with sanic-jwt as its ``store_refresh_token`` callback; any
    extra positional/keyword arguments are accepted and ignored. All users
    share the ``refresh_tokens`` hash, one field ``refresh_token_<user_id>``
    per user.
    """
    field = "refresh_token_{}".format(user_id)
    redis = RedisUtils()
    await redis.hset("refresh_tokens", field, refresh_token)


async def retrieve_refresh_token(request, user_id, *args, **kwargs):
    """Look up the refresh token previously stored for *user_id*.

    Registered with sanic-jwt as its ``retrieve_refresh_token`` callback;
    ``request`` and any extra arguments are accepted but unused. Returns
    whatever ``RedisUtils().hget`` yields for field ``refresh_token_<user_id>``
    of the ``refresh_tokens`` hash (presumably ``None`` when missing; this
    depends on the Redis client, so confirm).
    """
    field = "refresh_token_{}".format(user_id)
    redis = RedisUtils()
    return await redis.hget("refresh_tokens", field)


class JWTUtils(object):
    """Process-wide singleton that attaches sanic-jwt to a Sanic app.

    Every ``JWTUtils()`` call returns the same instance. ``init_app``
    remembers the first app it receives and sets up sanic-jwt only while no
    sanic-jwt object has been created yet.
    """
    __app = None  # first Sanic app handed to init_app
    __sanic_jwt = None  # result of sanic_jwt.Initialize, created once
    __instance = None  # the shared singleton instance

    def __new__(cls, *args, **kwargs):
        instance = cls.__instance
        if instance is None:
            instance = super().__new__(cls)
            cls.__instance = instance
        return instance

    @classmethod
    def init_app(cls, app):
        """Bind *app* and initialize sanic-jwt on it if not done already.

        Endpoints are mounted under ``/<app name with '_' -> '-'>``:
        ``/auth``, ``/verify`` and ``/refresh_token``. Token settings come
        from ``TOKEN_CONFIG``; refresh tokens are kept in Redis through
        ``store_refresh_token`` / ``retrieve_refresh_token``.
        """
        if cls.__app is None:
            cls.__app = app
        if cls.__sanic_jwt is not None:
            return
        options = dict(
            authentication_class=ExtAuthentication,
            url_prefix="/" + app.name.replace("_", "-"),
            path_to_authenticate="/auth",
            path_to_verify="/verify",
            path_to_refresh="/refresh_token",
            refresh_token_enabled=TOKEN_CONFIG.get(
                "refresh_token_enabled", True),
            expiration_delta=TOKEN_CONFIG.get(
                "expiration_delta", 7 * 24 * 60 * 60),
            secret=TOKEN_CONFIG.get("secret", "QingKe"),
            store_refresh_token=store_refresh_token,
            retrieve_refresh_token=retrieve_refresh_token,
        )
        cls.__sanic_jwt = Initialize(app, **options)

    @property
    def app(self):
        """The Sanic app recorded by the first ``init_app`` call, or None."""
        return self.__app


if __name__ == '__main__':
    # Stand-alone demo server: sanic-jwt's /auth, /verify and /refresh_token
    # endpoints plus a route reporting whether the caller's token has been
    # blacklisted.
    app = Sanic()
    sanicjwt = Initialize(
        app,
        authentication_class=ExtAuthentication,
        # Absolute prefix, the same form JWTUtils.init_app uses.
        url_prefix='/{}'.format(app.name),
        path_to_authenticate='/auth',
        path_to_verify='/verify',
        path_to_refresh='/refresh_token',
        refresh_token_enabled=TOKEN_CONFIG.get("refresh_token_enabled", True),
        expiration_delta=TOKEN_CONFIG.get(
            "expiration_delta", 86400),
        secret=TOKEN_CONFIG.get(
            "secret", "QingKe"),
        authenticate=lambda: True,
        # Refresh tokens are enabled, so wire in the Redis-backed callbacks
        # exactly as JWTUtils.init_app does.
        store_refresh_token=store_refresh_token,
        retrieve_refresh_token=retrieve_refresh_token,
    )

    @app.route("/api/authentication/check_black_list")
    @protected()
    async def check_black_list(request):
        """Return ``{"verify": False, ...}`` if the bearer token is blacklisted."""
        # Tolerate a missing header instead of crashing on None.split().
        parts = (request.headers.get("authorization") or "").split()
        token = parts[-1] if parts else ""
        # Go through check_token_blacklist so the lookup uses the same
        # "auth:block_list:<token>" key that store_token_blacklist writes.
        if token and await check_token_blacklist(token):
            return response.json(
                {"verify": False, "status": 401,
                 "reason": "token is in black list"})
        return response.json({"verify": True})

    app.run(debug=True, port=8889)