from pot_libs.es_util.es_utils import EsUtil
from pot_libs.logger import log
from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api.constants import POINT_15MIN_INDEX, INDEX
from unify_api.utils.es_query_body import EsQueryBody
from unify_api.utils.time_format import power_slots, range_to_type


async def pttl_max(cid, start, end, point_id=None, inline_id=None):
    """Find the largest summed ``pttl_max`` bucket in a time range (ES).

    The points to aggregate are chosen as follows:

    * ``inline_id`` truthy: every point of company ``cid`` on that incoming
      line with ``add_to_company = 1``;
    * ``point_id == -1``: every point of company ``cid`` with
      ``add_to_company = 1``;
    * otherwise: the single point ``point_id``.

    ``range_to_type(start, end)`` selects the index (via ``INDEX``) and the
    histogram granularity: ``"day"`` -> hourly buckets, ``"month"`` -> daily
    buckets, anything else -> monthly buckets.

    :param cid: company id used to look up points.
    :param start: range start, passed through to ``EsQueryBody``.
    :param end: range end, passed through to ``EsQueryBody``.
    :param point_id: a single point id, or ``-1`` for all company points.
    :param inline_id: incoming-line id; takes precedence over ``point_id``.
    :return: ``(max_val, max_val_time)``. ``("", "")`` when no point matches
        or the aggregation has no buckets; ``(0, "")`` when buckets exist
        but none has a positive sum. ``max_val_time`` is ``HH:mm`` for a day
        range, ``MM-dd`` for a month range and ``yyyy-MM`` otherwise.
    """
    # Resolve the list of point ids to aggregate over.
    if inline_id:
        # Points attached to the given incoming line.
        sql = "SELECT pid from `point` WHERE cid_belongedto = %s " \
              "and inlid_belongedto = %s and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql,
                                             args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company flagged
        # add_to_company = 1.
        sql = "SELECT pid from `point` WHERE cid_belongedto = %s " \
              "and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    else:
        point_list = [point_id]
    # Nothing to aggregate: skip the ES round trip instead of sending an
    # empty ``terms`` filter (same guard as pttl_max_new15).
    if not point_list:
        return "", ""
    terms = {"pid": point_list}

    # 1. Pick the index and the bucket granularity from the time range.
    date_type = range_to_type(start, end)
    index = INDEX[date_type]
    if date_type == "day":
        date_key = "hour"
    elif date_type == "month":
        date_key = "day"
    else:
        date_key = "month"

    # 2. Build the query: one bucket per ``date_key`` unit, summing pttl_max
    #    over all selected points inside each bucket.
    eqb = EsQueryBody(terms=terms, start=start, end=end,
                      date_key=date_key)
    query = eqb.query()
    query["aggs"] = {
        "time_column": {
            "date_histogram": {
                "field": date_key,
                "interval": date_key,
                "time_zone": "+08:00",
                "format": "yyyy-MM-dd HH:mm"
            },
            "aggs": {
                "pttl_max": {
                    "sum": {
                        "field": "pttl_max"
                    }
                }
            }
        }
    }
    log.info(index + f"====={query}")
    async with EsUtil() as es:
        es_re = await es.search_origin(body=query, index=index)
    buckets = es_re["aggregations"]["time_column"]["buckets"]
    if not buckets:
        return "", ""

    # 3. Scan the buckets for the maximum demand (first strict maximum wins).
    max_val = 0
    max_val_time = ""
    for bucket in buckets:
        bucket_sum = bucket["pttl_max"]["value"]
        if bucket_sum and bucket_sum > max_val:
            max_val = bucket_sum
            max_val_time = bucket["key_as_string"]

    # 4. Trim the "yyyy-MM-dd HH:mm" key to what the range type needs.
    if max_val_time:
        if date_type == "day":
            max_val_time = max_val_time.split(" ")[1]
        elif date_type == "month":
            max_val_time = max_val_time.split("-", 1)[1].split(" ")[0]
        else:
            max_val_time = max_val_time[:7]
    return max_val, max_val_time


async def pttl_max_new15(cid, start, end, point_id=None, inline_id=None):
    """Find the largest summed ``pttl_max`` period in a time range (MySQL).

    The points to aggregate are chosen as follows:

    * ``inline_id`` truthy: every point of company ``cid`` on that incoming
      line with ``add_to_company = 1``;
    * ``point_id == -1``: every point of company ``cid`` with
      ``add_to_company = 1``;
    * otherwise: the single point ``point_id``.

    ``range_to_type(start, end)`` selects the table and the grouping:
    ``"day"`` groups ``point_15min_electric`` by hour, ``"month"`` groups it
    by day, anything else groups ``point_1day_electric`` by month.

    :param cid: company id used to look up points.
    :param start: lower bound for ``create_time`` (inclusive, ``BETWEEN``).
    :param end: upper bound for ``create_time`` (inclusive, ``BETWEEN``).
    :param point_id: a single point id, or ``-1`` for all company points.
    :param inline_id: incoming-line id; takes precedence over ``point_id``.
    :return: ``(max_val, max_val_time)``. ``("", "")`` when no point
        matches; ``(0, "")`` when no group has a positive sum.
        ``max_val_time`` is ``HH:00`` for a day range, ``mm-dd`` for a month
        range and ``YYYY-mm`` otherwise.
    """
    # Resolve the list of point ids to aggregate over.
    if inline_id:
        # Points attached to the given incoming line.
        sql = "SELECT pid from `point` WHERE cid = %s " \
              "and inlid = %s and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql,
                                             args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company flagged
        # add_to_company = 1.
        sql = "SELECT pid from `point` WHERE cid = %s " \
              "and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    else:
        point_list = [point_id]
    if not point_list:
        return "", ""

    # 1. Pick the source table and the grouping format from the time range.
    #    ``%`` is doubled because the statement is executed with args, so
    #    the driver's %-substitution collapses ``%%`` back to ``%``.
    date_type = range_to_type(start, end)
    if date_type == "day":
        table_name = "point_15min_electric"
        time_format = "%%Y-%%m-%%d %%H:00:00"
    elif date_type == "month":
        table_name = "point_15min_electric"
        time_format = "%%Y-%%m-%%d"
    else:
        table_name = "point_1day_electric"
        time_format = "%%Y-%%m"

    # 2. Sum pttl_max over the selected points per time group. Only the
    #    constants chosen above are interpolated into the statement text;
    #    caller-supplied ``start``/``end`` are bound as parameters so they
    #    cannot alter the SQL.
    sum_sql = f"SELECT DATE_FORMAT(create_time, '{time_format}') time_date, " \
              f"sum(pttl_max) pttl_max_sum FROM {table_name} WHERE " \
              f"pid in %s and create_time BETWEEN %s and %s " \
              f"GROUP BY time_date ORDER BY time_date"
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sum_sql,
                                    args=(point_list, start, end))

    # 3. Scan the groups for the maximum demand (first strict maximum wins).
    max_val = 0
    max_val_time = ""
    for row in datas:
        group_sum = row.get("pttl_max_sum")
        if group_sum and group_sum > max_val:
            max_val = group_sum
            max_val_time = row["time_date"]

    # 4. Trim the formatted group key to what the range type needs.
    if max_val_time:
        if date_type == "day":
            max_val_time = str(max_val_time)[11:16]
        elif date_type == "month":
            max_val_time = str(max_val_time)[5:10]
        else:
            max_val_time = str(max_val_time)[:7]
    return max_val, max_val_time


async def pttl_max_15min(cid, start, end, point_id=None, inline_id=None):
    """Find the peak load from 15-minute data in Elasticsearch.

    Used for load distribution, where the highest load has to be taken at
    15-minute resolution. The points to aggregate are chosen as follows:

    * ``inline_id`` truthy: every point of company ``cid`` on that incoming
      line;
    * ``point_id == -1``: every point of company ``cid``;
    * otherwise: the single point ``point_id``.

    Unlike ``pttl_max``, the point lookups here do not filter on
    ``add_to_company``.

    :param cid: company id used to look up points.
    :param start: range start, passed through to ``EsQueryBody``.
    :param end: range end, passed through to ``EsQueryBody``.
    :param point_id: a single point id, or ``-1`` for all company points.
    :param inline_id: incoming-line id; takes precedence over ``point_id``.
    :return: ``(max_val, max_val_time)``. ``("", "")`` when no point matches
        or the aggregation has no buckets; ``(0, "")`` when buckets exist
        but none has a positive sum. ``max_val_time`` is ``MM-dd HH:mm``.
    """
    # Resolve the list of point ids to aggregate over.
    if inline_id:
        # Points attached to the given incoming line.
        sql = "SELECT pid from `point` WHERE cid_belongedto = %s " \
              "and inlid_belongedto = %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql,
                                             args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company.
        sql = "SELECT pid from `point` WHERE cid_belongedto = %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    else:
        point_list = [point_id]
    # Nothing to aggregate: skip the ES round trip instead of sending an
    # empty ``terms`` filter.
    if not point_list:
        return "", ""
    terms = {"pid": point_list}

    # 1. Fixed index and granularity: 15-minute buckets on quarter_time.
    index = POINT_15MIN_INDEX
    date_key = "quarter_time"
    interval = "15m"

    # 2. Build the query: sum pttl_max over all selected points inside each
    #    15-minute bucket.
    eqb = EsQueryBody(terms=terms, start=start, end=end,
                      date_key=date_key)
    query = eqb.query()
    query["aggs"] = {
        "time_column": {
            "date_histogram": {
                "field": date_key,
                "interval": interval,
                "time_zone": "+08:00",
                "format": "yyyy-MM-dd HH:mm"
            },
            "aggs": {
                "pttl_max": {
                    "sum": {
                        "field": "pttl_max"
                    }
                }
            }
        }
    }
    log.info(index + f"====={query}")
    async with EsUtil() as es:
        es_re = await es.search_origin(body=query, index=index)
    buckets = es_re["aggregations"]["time_column"]["buckets"]
    if not buckets:
        return "", ""

    # 3. Scan the buckets for the maximum demand (first strict maximum wins).
    max_val = 0
    max_val_time = ""
    for bucket in buckets:
        bucket_sum = bucket["pttl_max"]["value"]
        if bucket_sum and bucket_sum > max_val:
            max_val = bucket_sum
            max_val_time = bucket["key_as_string"]

    # 4. Drop the leading "yyyy-" so the result reads "MM-dd HH:mm".
    if max_val_time:
        max_val_time = max_val_time[5:]
    return max_val, max_val_time


async def pttl_max_15min_new15(cid, start, end, point_id=None, inline_id=None):
    """Find the peak ``pttl_max`` row in ``point_15min_electric`` (MySQL).

    The points to inspect are chosen as follows:

    * ``inline_id`` truthy: every point of company ``cid`` on that incoming
      line;
    * ``point_id == -1``: every point of company ``cid``;
    * otherwise: the single point ``point_id``.

    The point lookups here do not filter on ``add_to_company``.

    :param cid: company id used to look up points.
    :param start: lower bound for ``create_time`` (inclusive, ``BETWEEN``).
    :param end: upper bound for ``create_time`` (inclusive, ``BETWEEN``).
    :param point_id: a single point id, or ``-1`` for all company points.
    :param inline_id: incoming-line id; takes precedence over ``point_id``.
    :return: ``(max_val, max_val_time)``. ``("", "")`` when no point matches
        or no row is found; ``(0, "")`` when rows exist but none has a
        positive ``pttl_max``. ``max_val_time`` is ``str(create_time)``
        with its first five characters removed.
    """
    # Resolve the list of point ids to inspect.
    if inline_id:
        # Points attached to the given incoming line.
        sql = "SELECT pid from `point` WHERE cid = %s and inlid = %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company.
        sql = "SELECT pid from `point` WHERE cid= %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    else:
        point_list = [point_id]
    # An empty list would presumably be rendered as ``pid in ()``, which is
    # not valid MySQL; there is nothing to query in that case anyway.
    if not point_list:
        return "", ""

    # Fetch the 15-minute rows in range. ``start``/``end`` are bound as
    # parameters rather than formatted into the statement text so that
    # caller-supplied values cannot alter the SQL.
    sql = "SELECT create_time, pttl_max FROM `point_15min_electric` " \
          "where pid in %s and create_time BETWEEN %s and %s"
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sql, args=(point_list, start, end))
    if not datas:
        return "", ""

    # Scan the rows for the maximum demand (first strict maximum wins).
    # NOTE(review): this is a per-row maximum; pttl_max_15min sums pttl_max
    # across points per 15-minute bucket before taking the maximum. Confirm
    # which behaviour is intended when several points are selected.
    max_val = 0
    max_val_time = ""
    for row in datas:
        row_value = row["pttl_max"]
        if row_value and row_value > max_val:
            max_val = row_value
            max_val_time = str(row["create_time"])

    # Drop the first five characters (presumably the "YYYY-" prefix of a
    # datetime string - TODO confirm the type of create_time).
    if max_val_time:
        max_val_time = max_val_time[5:]
    return max_val, max_val_time
