import binascii
import json

from jwt.compat import binary_type, text_type  # NOTE(review): jwt.compat appears to be gone in PyJWT >= 2.0; confirm the pinned version
from jwt.exceptions import DecodeError
from jwt.utils import base64url_decode
from sanic import Sanic, response
from sanic_jwt import Initialize
from sanic_jwt.decorators import protected

from pot_libs.aredis_util import aredis_utils
from pot_libs.aredis_util.aredis_utils import RedisUtils
from unify_api.modules.users.procedures.login_pds import ExtAuthentication

TOKEN_CONFIG = {
    "expiration_delta": 7 * 24 * 60 * 60,  # access-token lifetime in seconds: 12h -> 7 days
    # Secret generated with: http://tool.c7sky.com/password/
    # NOTE(review): a signing secret committed to source is a security risk;
    # consider loading it from the environment / a secret store.
    "secret": "NI3g8jVepsoiK0aH7RgKrPSc9MQ4GwnAedDYzO65fQiLWEUk1CBJuthZnTFb2vyq",
    "refresh_token_enabled": True,
}

REFRESH_CONFIG = {
    "expiration_delta": 3600 * 48,
    "leeway": 60 * 3,
}


def _blacklist_key(token):
    """Return the Redis key under which a blacklisted ``token`` is stored.

    Accepts ``str`` or ``bytes``; bytes are decoded first so both forms of the
    same token map to the same key.
    """
    if isinstance(token, bytes):
        token = token.decode('utf-8', 'ignore')
    return "auth:block_list:%s" % token


async def check_token_blacklist(token):
    """Return True if ``token`` has been blacklisted, else False.

    :param token: the encoded JWT (``str`` or ``bytes``).
    """
    if await aredis_utils.RedisUtils().get(_blacklist_key(token)):
        return True
    return False


# When a user logs out, store the user's token in the blacklist.
async def store_token_blacklist(token, *args, **kwargs):
    """Blacklist ``token`` in Redis so later checks reject it.

    The payload segment is base64url-decoded WITHOUT signature verification,
    only to read the ``user_id`` claim, which is stored as the Redis value.
    The entry's TTL is ``TOKEN_CONFIG["expiration_delta"]`` seconds, the same
    lifetime passed to sanic_jwt in ``JWTUtils.init_app``. (Assumes
    ``RedisUtils.setex`` mirrors Redis ``SETEX key seconds value``.)

    :param token: the encoded JWT, as ``str`` or ``bytes``.
    :raises DecodeError: if the token is not str/bytes, does not split into
        exactly three ``.``-separated segments, or its payload is not a
        base64url-encoded UTF-8 JSON object.
    """
    if isinstance(token, text_type):
        token = token.encode('utf-8')
    if not issubclass(type(token), binary_type):
        raise DecodeError("Invalid token type. Token must be a {0}".format(
            binary_type))

    try:
        _, payload_segment, _ = token.rsplit(b'.')
    except ValueError as err:
        raise DecodeError('Not enough segments') from err

    try:
        payload = base64url_decode(payload_segment)
    except (TypeError, binascii.Error) as err:
        raise DecodeError('Invalid payload padding') from err

    # Surface malformed payloads as DecodeError (what callers expect) rather
    # than a raw UnicodeDecodeError / JSONDecodeError / AttributeError.
    try:
        payload = json.loads(payload.decode('utf-8'))
    except ValueError as err:
        raise DecodeError('Invalid payload string: %s' % err) from err
    if not isinstance(payload, dict):
        raise DecodeError('Invalid payload string: must be a json object')

    # TTL
    exp = TOKEN_CONFIG.get("expiration_delta", 86400)
    # NOTE(review): falls back to user_id 1 when the claim is missing.
    user_id = payload.get("user_id", 1)
    await aredis_utils.RedisUtils().setex(_blacklist_key(token), exp, user_id)


async def store_refresh_token(user_id, refresh_token, *args, **kwargs):
    """Save ``refresh_token`` for ``user_id`` in the ``refresh_tokens`` hash."""
    key = f'refresh_token_{user_id}'
    await RedisUtils().hset("refresh_tokens", key, refresh_token)


async def retrieve_refresh_token(request, user_id, *args, **kwargs):
    """Return the refresh token stored for ``user_id`` (see store_refresh_token)."""
    key = f'refresh_token_{user_id}'
    return await RedisUtils().hget("refresh_tokens", key)


class JWTUtils(object):
    """Singleton that wires sanic_jwt into a Sanic app.

    ``JWTUtils()`` always returns the same instance. ``init_app`` only takes
    effect the first time it is called; later calls are no-ops.
    """
    __app = None
    __sanic_jwt = None
    __instance = None

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance lazily on first construction.
        if cls.__instance is None:
            cls.__instance = super().__new__(cls)
        return cls.__instance

    @classmethod
    def init_app(cls, app):
        """Remember ``app`` and register the sanic_jwt endpoints on it.

        Endpoints are mounted under ``/<app name with '_' replaced by '-'>``,
        e.g. an app named ``foo_bar`` gets ``/foo-bar/auth``,
        ``/foo-bar/verify`` and ``/foo-bar/refresh_token``.
        """
        if cls.__app is None:
            cls.__app = app
        if cls.__sanic_jwt is None:
            cls.__sanic_jwt = Initialize(
                app,
                authentication_class=ExtAuthentication,
                url_prefix='/{}'.format(app.name.replace('_', '-')),
                path_to_authenticate='/auth',
                path_to_verify='/verify',
                path_to_refresh='/refresh_token',
                refresh_token_enabled=TOKEN_CONFIG.get(
                    "refresh_token_enabled", True),
                expiration_delta=TOKEN_CONFIG.get(
                    "expiration_delta", 7 * 24 * 60 * 60),
                secret=TOKEN_CONFIG.get("secret", "QingKe"),
                store_refresh_token=store_refresh_token,
                retrieve_refresh_token=retrieve_refresh_token,
            )

    @property
    def app(self):
        """The Sanic app passed to the first ``init_app`` call, or None."""
        return self.__app


if __name__ == '__main__':
    # Stand-alone harness for manually exercising the blacklist check.
    # NOTE(review): recent Sanic releases require an explicit app name here.
    app = Sanic()
    sanicjwt = Initialize(
        app,
        authentication_class=ExtAuthentication,
        # NOTE(review): './' differs from the '/' prefix used by
        # JWTUtils.init_app -- confirm this is intentional.
        url_prefix='./{}'.format(app.name),
        path_to_authenticate='/auth',
        path_to_verify='/verify',
        path_to_refresh='/refresh_token',
        refresh_token_enabled=TOKEN_CONFIG.get("refresh_token_enabled", True),
        expiration_delta=TOKEN_CONFIG.get("expiration_delta", 86400),
        secret=TOKEN_CONFIG.get("secret", "QingKe"),
        authenticate=lambda: True,
    )

    @app.route("/api/authentication/check_black_list")
    @protected()
    async def check_black_list(request):
        """Report whether the bearer token of this request is blacklisted."""
        authorization = request.headers.get("authorization") or ""
        parts = authorization.split()
        token = parts[-1] if parts else ""
        # Use the same key scheme as store_token_blacklist; the previous bare
        # `get(token)` lookup could never find a blacklisted token.
        if token and await check_token_blacklist(token):
            return response.json(
                {"verify": False, "status": 401, "reason": "token is invalid"})
        return response.json({"verify": True})

    app.run(debug=True, port=8889)