import json
import binascii
from sanic import Sanic, response
from sanic_jwt.decorators import protected
from sanic_jwt import Initialize
from pot_libs.aredis_util import aredis_utils
from jwt.compat import binary_type, text_type
from jwt.utils import base64url_decode
from jwt.exceptions import DecodeError

from pot_libs.aredis_util.aredis_utils import RedisUtils
from unify_api.modules.users.procedures.login_pds import ExtAuthentication

# Access-token settings. They are passed to sanic-jwt's Initialize() below, and
# "expiration_delta" is also reused as the TTL of block-list entries in Redis.
# NOTE(review): the signing secret is committed in source; consider loading it
# from the environment or a secrets store instead.
TOKEN_CONFIG = dict(
    # Token lifetime in seconds: 7 days (raised from 12 hours).
    expiration_delta=60 * 60 * 24 * 7,
    # Signing secret, generated with http://tool.c7sky.com/password/
    secret="NI3g8jVepsoiK0aH7RgKrPSc9MQ4GwnAedDYzO65fQiLWEUk1CBJuthZnTFb2vyq",
    refresh_token_enabled=True,
)

# Refresh-token settings. Nothing in the visible part of this module reads
# them. Values are presumably in seconds (48 hours / 3 minutes) -- TODO confirm.
REFRESH_CONFIG = dict(
    expiration_delta=48 * 3600,
    leeway=3 * 60,
)


async def check_token_blacklist(token):
    """Return True when Redis holds a truthy block-list entry for *token*.

    The entry is looked up under the key ``auth:block_list:<token>``; a
    missing or falsy value yields False.
    """
    redis_key = "auth:block_list:%s" % token
    blocked_entry = await aredis_utils.RedisUtils().get(redis_key)
    return bool(blocked_entry)


# Called when a user logs out: remember the token so it can be rejected later.
async def store_token_blacklist(token, *args, **kwargs):
    """Add a JWT to the Redis block list.

    The payload segment is decoded (the signature is NOT verified here) only
    to read ``user_id``, which becomes the value of the Redis entry
    ``auth:block_list:<token>``.

    Args:
        token: The encoded JWT, as text or bytes.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Raises:
        DecodeError: If the token is neither text nor bytes, does not split
            into exactly three dot-separated segments, or its payload is not
            a base64url-encoded, UTF-8 JSON object.
    """
    if isinstance(token, text_type):
        token = token.encode('utf-8')
    if not isinstance(token, binary_type):
        raise DecodeError("Invalid token type. Token must be a {0}".format(
            binary_type))

    # header.payload.signature -- only the payload is needed.
    try:
        _, payload_segment, _ = token.split(b'.')
    except ValueError as err:
        raise DecodeError('Not enough segments') from err

    try:
        raw_payload = base64url_decode(payload_segment)
    except (TypeError, binascii.Error) as err:
        raise DecodeError('Invalid payload padding') from err

    # UnicodeDecodeError and json.JSONDecodeError are both ValueError
    # subclasses; report them as DecodeError like every other malformed-token
    # case instead of leaking them to the caller.
    try:
        payload = json.loads(raw_payload.decode('utf-8'))
    except ValueError as err:
        raise DecodeError(
            'Invalid payload string: {0}'.format(err)) from err
    if not isinstance(payload, dict):
        raise DecodeError('Invalid payload string: must be a json object')

    # TTL: the entry lives for one full configured token lifetime.
    ttl = TOKEN_CONFIG.get("expiration_delta", 86400)
    user_id = payload.get("user_id", 1)

    # Build the key from text, not bytes, so "%s" does not render as "b'...'".
    token_text = token.decode('utf-8', 'ignore')
    await aredis_utils.RedisUtils().setex(
        "auth:block_list:%s" % token_text, ttl, user_id)


async def store_refresh_token(user_id, refresh_token, *args, **kwargs):
    """Save *refresh_token* for *user_id* in the ``refresh_tokens`` Redis hash.

    The hash field is ``refresh_token_<user_id>``. Extra positional and
    keyword arguments are accepted and ignored.
    """
    await RedisUtils().hset(
        "refresh_tokens",
        f'refresh_token_{user_id}',
        refresh_token,
    )


async def retrieve_refresh_token(request, user_id, *args, **kwargs):
    """Fetch the refresh token stored for *user_id*.

    Reads field ``refresh_token_<user_id>`` of the ``refresh_tokens`` Redis
    hash and returns whatever the Redis helper yields. *request* and any
    extra arguments are accepted and ignored.
    """
    stored_token = await RedisUtils().hget(
        "refresh_tokens", f'refresh_token_{user_id}')
    return stored_token


class JWTUtils:
    """Shared holder for the Sanic app and its sanic-jwt initialisation.

    Every instantiation returns the same object, and ``init_app`` registers
    the app and runs sanic-jwt's ``Initialize`` at most once each.
    """

    # Class-level state; the double underscore name-mangles these attributes.
    __instance = None
    __app = None
    __sanic_jwt = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        existing = cls.__instance
        if existing is None:
            existing = cls.__instance = super().__new__(cls)
        return existing

    @classmethod
    def init_app(cls, app):
        """Remember *app* and attach sanic-jwt to it.

        The first app passed in is kept; later calls do not replace it.
        ``Initialize`` is only run while no sanic-jwt object has been stored.
        Endpoints are mounted under ``/<app name with '_' replaced by '-'>``.
        """
        if cls.__app is None:
            cls.__app = app

        if cls.__sanic_jwt is not None:
            return

        service_name = app.name.replace('_', '-')
        refresh_enabled = TOKEN_CONFIG.get("refresh_token_enabled", True)
        lifetime = TOKEN_CONFIG.get("expiration_delta", 7*24*60*60)
        signing_secret = TOKEN_CONFIG.get("secret", "QingKe")
        cls.__sanic_jwt = Initialize(
            app,
            authentication_class=ExtAuthentication,
            url_prefix='/{}'.format(service_name),
            path_to_authenticate='/auth',
            path_to_verify='/verify',
            path_to_refresh='/refresh_token',
            refresh_token_enabled=refresh_enabled,
            expiration_delta=lifetime,
            secret=signing_secret,
            store_refresh_token=store_refresh_token,
            retrieve_refresh_token=retrieve_refresh_token,
        )

    @property
    def app(self):
        """The app registered through ``init_app`` (None before that)."""
        return self.__app


if __name__ == '__main__':
    # Stand-alone demo server: mounts sanic-jwt on a throwaway app and exposes
    # one protected route that reports whether the caller's token is blocked.
    app = Sanic()
    sanicjwt = Initialize(
        app,
        authentication_class=ExtAuthentication,
        # Absolute prefix, as in JWTUtils.init_app (previously './<name>').
        url_prefix='/{}'.format(app.name),
        path_to_authenticate='/auth',
        path_to_verify='/verify',
        path_to_refresh='/refresh_token',
        refresh_token_enabled=TOKEN_CONFIG.get("refresh_token_enabled", True),
        expiration_delta=TOKEN_CONFIG.get(
            "expiration_delta", 86400),
        secret=TOKEN_CONFIG.get(
            "secret", "QingKe"),
        # NOTE(review): placeholder callback kept from the original demo;
        # confirm it matches the signature sanic-jwt expects.
        authenticate=lambda: True,
    )

    @app.route("/api/authentication/check_black_list")
    @protected()
    async def check_black_list(request):
        """Tell the caller whether the bearer token is on the block list."""
        authorization = request.headers.get("authorization")
        # The token is the last whitespace-separated part of the header.
        token = authorization.split()[-1]
        # Use the shared helper so the lookup hits the same
        # "auth:block_list:<token>" key that store_token_blacklist writes.
        if await check_token_blacklist(token):
            return response.json(
                {"verify": False, "status": 401, "reason": "token is invalid"})
        return response.json({"verify": True})

    app.run(debug=True, port=8889)