from collections import defaultdict

import pendulum

from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api.modules.common.procedures.common_cps import proxy_safe_run_info
from unify_api.modules.common.procedures.points import get_points, get_points_new15


async def proxy_electric_count_info(cids, month_str):
    """Aggregate electrical-safety statistics for a set of companies in one month.

    Reads the rows of ``safe_health_stats_cid`` for ``cids`` / ``month_str`` and
    derives:

    * ``alarm_content``: totals of the ``temp``, ``residual_curr`` and
      ``ele_para`` columns over all fetched rows;
    * ``security_level``: how many distinct cids fall into each risk bucket,
      based on a per-cid alarm score;
    * ``total_cid_cnt`` / ``total_run_days`` / ``total_safe_run_days``:
      derived from ``proxy_safe_run_info`` over the calendar month.

    The alarm score of a cid is ``(2*lv_1 + lv_2 + 0.5*lv_3) / point_count``,
    capped at 15. ``point_count`` is the number of entries that
    ``get_points_new15`` returns for that cid. A cid with no points scores 0.

    :param cids: company ids. When empty/None, no I/O is performed and every
        count in the result is 0.
    :param month_str: month such as ``"2023-05"``. It is matched against
        ``cal_month`` and split on ``-`` into year and month numbers.
    :return: dict with keys ``total_cid_cnt``, ``total_run_days``,
        ``total_safe_run_days``, ``alarm_content`` and ``security_level``.
    """
    stat_rows = []
    points_by_cid = {}
    if cids:
        async with MysqlUtil() as conn:
            stat_rows = await conn.fetchall(
                "select * from safe_health_stats_cid where cid in %s and cal_month=%s",
                args=(cids, month_str,),
            )
        points_by_cid = await get_points_new15(cids)

    # Output key -> source column in safe_health_stats_cid.
    content_columns = (
        ("temperature_cnt", "temp"),
        ("residual_current_cnt", "residual_curr"),
        ("electric_param_cnt", "ele_para"),
    )
    alarm_content = {key: 0 for key, _ in content_columns}

    score_by_cid = {}
    for row in stat_rows:
        lv_1, lv_2, lv_3 = row["lv_1"], row["lv_2"], row["lv_3"]
        cid = row["cid"]
        point_count = len(points_by_cid.get(cid) or {})
        if point_count:
            score = min((lv_1 * 2 + lv_2 * 1 + lv_3 * 0.5) / point_count, 15)
        else:
            score = 0
        score_by_cid[cid] = score
        for key, column in content_columns:
            alarm_content[key] += row[column]

    # Inclusive upper bound of each risk bucket, checked in order; any score
    # above the last bound is counted as "high_cnt".
    level_bounds = (
        (1, "security_cnt"),
        (2, "pretty_low_cnt"),
        (5, "medium_cnt"),
        (10, "pretty_high_cnt"),
    )
    security_level = {name: 0 for _, name in level_bounds}
    security_level["high_cnt"] = 0
    for score in score_by_cid.values():
        bucket = next(
            (name for bound, name in level_bounds if score <= bound), "high_cnt"
        )
        security_level[bucket] += 1

    run_info_by_cid = {}
    if cids:
        year, month = month_str.split("-")
        month_start = pendulum.datetime(
            year=int(year), month=int(month), day=1, tz="Asia/Shanghai"
        )
        # Last second of the month: step one month forward, then one second back.
        month_end = month_start.add(months=1, seconds=-1)
        stamp = "YYYY-MM-DD HH:mm:ss"
        run_info_by_cid = await proxy_safe_run_info(
            cids,
            start_time_str=month_start.format(stamp),
            end_time_str=month_end.format(stamp),
        )

    # A cid whose total_days is 0 is excluded from the company count. Per the
    # original author's note, such a company did not exist yet in the selected
    # month, so it should not count toward the total number of users.
    run_infos = run_info_by_cid.values()
    active_cid_count = sum(1 for info in run_infos if info["total_days"] != 0)
    run_days = sum(info["total_days"] for info in run_infos)
    safe_run_days = sum(info["safe_run_days"] for info in run_infos)

    return {
        "total_cid_cnt": active_cid_count,
        "total_run_days": run_days,
        "total_safe_run_days": safe_run_days,
        "alarm_content": alarm_content,
        "security_level": security_level,
    }


async def page_proxy_electric_info(
    cids, month, page_size, page_num, sort_field="electric_index", sort_direction="asc"
):
    """Return one page of per-company electrical-safety stats for a month.

    Rows come from ``safe_health_stats_cid`` filtered by ``cids`` and
    ``cal_month``. They are ordered by the column mapped from ``sort_field``,
    with ``cid`` as a tiebreaker so that paging is stable. ``total`` is a
    separate ``COUNT(*)`` over the same filter.

    :param cids: company ids to include. Empty/None returns an empty page
        without touching the database.
    :param month: value matched against ``cal_month``.
    :param page_size: rows per page (SQL ``LIMIT``).
    :param page_num: 1-based page number. Values below 1 are treated as page 1.
    :param sort_field: API sort key; must be a key of ``sort_field_map``.
    :param sort_direction: ``"asc"`` sorts ascending; any other value sorts
        descending.
    :return: ``{"total": <matching row count>, "rows": [<row dict>, ...]}``.
        ``company_name`` is ``None`` when no ``company`` row exists for a cid.
    :raises KeyError: if ``sort_field`` is not a supported sort key.
    """
    if not cids:
        return {"total": 0, "rows": []}
    sort_field_map = {
        "electric_index": "safe_exp",
        "first_alarm_cnt": "lv_1",
        "second_alarm_cnt": "lv_2",
        "third_alarm_cnt": "lv_3",
        "electric_alarm_cnt": "ele_para",
        "temperature_alarm_cnt": "temp",
        "residual_current_cnt": "residual_curr",
    }
    # Allowlist lookup: unknown keys raise KeyError. This also guarantees that
    # only known column names are interpolated into the SQL below.
    sort_column = sort_field_map[sort_field]
    direction = "asc" if sort_direction == "asc" else "desc"
    # A negative OFFSET is a SQL error, so clamp to the first page.
    offset = max(page_num - 1, 0) * page_size

    # Column names and ASC/DESC are identifiers/keywords, not values, so they
    # cannot be bound as %s placeholders. The driver would quote them as string
    # literals, making ORDER BY sort by a constant. They are interpolated
    # instead; both come from the fixed allowlists above. ``cid`` is a
    # secondary key so rows with equal sort values keep a deterministic order
    # across LIMIT/OFFSET pages; otherwise rows may repeat or be skipped.
    sql = (
        "select * from safe_health_stats_cid where cid in %s and cal_month=%s "
        f"order by {sort_column} {direction}, cid limit %s offset %s"
    )
    async with MysqlUtil() as conn:
        safe_stats_list = await conn.fetchall(
            sql, args=(cids, month, page_size, offset)
        )
        total_sql = (
            "select count(*) as total from safe_health_stats_cid where cid in %s and cal_month=%s"
        )
        total_result = await conn.fetchone(total_sql, args=(cids, month))
        total = total_result["total"]

        company_sql = "select cid, shortname from company where cid in %s"
        companys = await conn.fetchall(company_sql, args=(cids,))
        company_map = {i["cid"]: i for i in companys}

    return {
        "total": total,
        "rows": [
            {
                # A missing company row should not fail the whole page.
                "company_name": company_map.get(i["cid"], {}).get("shortname"),
                "electric_index": i["safe_exp"],
                "first_alarm_cnt": i["lv_1"],
                "second_alarm_cnt": i["lv_2"],
                "third_alarm_cnt": i["lv_3"],
                "electric_alarm_cnt": i["ele_para"],
                "temperature_alarm_cnt": i["temp"],
                "residual_current_cnt": i["residual_curr"],
            }
            for i in safe_stats_list
        ],
    }
