from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api.utils.time_format import range_to_type


async def load_pttl_max(cid, start, end, point_id=None, inline_id=None):
    """Find the largest per-timestamp sum of ``pttl_max`` and when it occurred.

    The set of points (pids) is resolved as follows:

    * ``inline_id`` truthy: points of company ``cid`` on that incoming line
      with ``add_to_company = 1``.
    * ``point_id == -1`` ("all"): every point of company ``cid`` with
      ``add_to_company = 1``.
    * otherwise: the single ``point_id``.

    ``pttl_max`` is then summed across those points for each ``create_time``
    between ``start`` and ``end``, and the largest positive sum is picked.

    Args:
        cid: Company id used to filter the ``point`` table.
        start: Lower bound for ``create_time`` (passed to SQL ``BETWEEN``).
        end: Upper bound for ``create_time`` (passed to SQL ``BETWEEN``).
        point_id: A single point id, or ``-1`` to select all company points.
        inline_id: Incoming-line id; takes precedence over ``point_id``.

    Returns:
        ``("", "")`` when no point could be resolved. Otherwise a tuple
        ``(max_val, max_val_time)``; ``max_val`` is ``0`` and
        ``max_val_time`` is ``""`` when no positive sum was found. The time
        is formatted as ``"%H:00"``, ``"%m-%d"`` or ``"%Y-%m"`` depending on
        the value returned by ``range_to_type(start, end)``.
    """
    if inline_id:
        # Resolve the points attached to the requested incoming line.
        sql = "SELECT pid from `point` WHERE cid = %s " \
              "and inlid = %s and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql,
                                             args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company flagged
        # add_to_company = 1.
        sql = "SELECT pid from `point` WHERE cid = %s " \
              "and add_to_company = 1"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    else:
        point_list = [point_id]
    if not point_list:
        # Nothing to query; also avoids sending an empty IN list.
        return "", ""

    # The time range decides which table is read.
    date_type = range_to_type(start, end)
    if date_type == "day":
        table_name = "point_15min_electric"
    else:
        table_name = "point_1day_electric"
    # Only the table name (chosen from the two literals above) is
    # interpolated; the pid list and the time bounds are bound as query
    # parameters rather than spliced into the SQL text.
    sum_sql = f"SELECT create_time, sum(pttl_max) pttl_max_sum " \
              f"FROM {table_name} WHERE " \
              f"pid in %s and create_time BETWEEN %s and %s " \
              f"GROUP BY create_time ORDER BY create_time"
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sum_sql,
                                    args=(point_list, start, end))

    # Maximum demand: keep the first row holding the largest positive sum.
    max_val = 0
    max_val_time = ""
    for data in datas:
        pttl_max_sum = data.get("pttl_max_sum")
        if pttl_max_sum and pttl_max_sum > max_val:
            max_val = pttl_max_sum
            max_val_time = data["create_time"]

    # The output time format depends on the range type.
    if max_val_time:
        if date_type == "day":
            max_val_time = max_val_time.strftime("%H:00")
        elif date_type == "month":
            max_val_time = max_val_time.strftime("%m-%d")
        else:
            max_val_time = max_val_time.strftime("%Y-%m")
    return max_val, max_val_time


async def load_pttl_max_15min(cid, start, end, point_id=None, inline_id=None):
    """Find the largest ``pttl_max`` row in ``point_15min_electric``.

    The set of points (pids) is resolved as follows:

    * ``inline_id`` truthy: points of company ``cid`` on that incoming line.
    * ``point_id == -1`` ("all"): every point of company ``cid``.
    * ``point_id is None``: no point is selected.
    * otherwise: the single ``point_id``.

    Args:
        cid: Company id used to filter the ``point`` table.
        start: Lower bound for ``create_time`` (passed to SQL ``BETWEEN``).
        end: Upper bound for ``create_time`` (passed to SQL ``BETWEEN``).
        point_id: A single point id, or ``-1`` to select all company points.
        inline_id: Incoming-line id; takes precedence over ``point_id``.

    Returns:
        ``("", "")`` when no point is selected or no row matches. Otherwise
        ``(max_val, max_val_time)`` where ``max_val_time`` is characters
        5..15 of ``str(create_time)``; ``(0, "")`` when rows exist but none
        has a positive ``pttl_max``.
    """
    if inline_id:
        sql = "SELECT pid from `point` WHERE cid = %s and inlid = %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid, inline_id))
        point_list = [point.get("pid") for point in point_info]
    elif point_id == -1:
        # "All" selected: every point of the company.
        # NOTE(review): the original comment said this should be limited to
        # add_to_company = 1, but the query has no such filter - confirm
        # which is intended.
        sql = "SELECT pid from `point` WHERE cid= %s"
        async with MysqlUtil() as conn:
            point_info = await conn.fetchall(sql=sql, args=(cid,))
        point_list = [point.get("pid") for point in point_info]
    elif point_id is None:
        # No point requested; a NULL pid cannot match any row.
        point_list = []
    else:
        point_list = [point_id]
    if not point_list:
        # Same result as "no rows", without sending an empty IN list
        # (presumably rendered as ``pid in ()``, invalid SQL) to the server.
        return "", ""

    # The pid list and the time bounds are bound as query parameters rather
    # than spliced into the SQL text.
    sql = "SELECT create_time, pttl_max FROM `point_15min_electric` " \
          "where pid in %s and create_time BETWEEN %s and %s"
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sql, args=(point_list, start, end))
    if not datas:
        return "", ""

    # Maximum demand: keep the first row holding the largest positive value.
    max_val = 0
    max_val_time = ""
    for res in datas:
        mdp_max_value = res["pttl_max"]
        if mdp_max_value and mdp_max_value > max_val:
            max_val = mdp_max_value
            max_val_time = str(res["create_time"])
    if max_val_time:
        # Drop the leading "YYYY-" and the trailing seconds.
        max_val_time = max_val_time[5:16]
    return max_val, max_val_time