from datetime import datetime
from pot_libs.es_util.es_utils import EsUtil
from pot_libs.logger import log
from unify_api.constants import COMPANY_15MIN_POWER, COMPANY_DAY_AHEAD_PRE
from unify_api.modules.elec_charge.common.utils import interval_type
from unify_api.modules.load_analysis.components.load_forecast_cps import (
    ForecastResp,
    LoadValue,
)
from unify_api.utils.common_utils import choose_list
from unify_api.utils.es_query_body import EsQueryBody, es_process, \
    sql_time_process
from unify_api.utils.time_format import time_pick_transf
from unify_api.modules.load_analysis.dao.load_forecast_dao import \
    get_kwh_p_dao, get_pred_p_dao


def _format_load(value):
    """Round a load reading to 2 decimals, or return ``""`` if it is missing.

    A reading of exactly 0 is valid data and is kept as ``0.0``. Any other
    falsy value (e.g. ``None``) is treated as "no data" and becomes ``""``.
    """
    if value in [0, 0.0]:
        return 0.0
    return round(value, 2) if value else ""


def _number_or_blank(value):
    """Return *value* if its type is exactly ``int``/``float``, else ``""``."""
    # Exact type check (not isinstance) so bools are rejected, as before.
    return value if type(value) in [int, float] else ""


def _quarter_hour_energy(p):
    """Return ``round(p * 0.25, 2)`` for numeric *p*, else ``""``.

    NOTE(review): 0.25 presumably converts a 15-minute (0.25 h) power value
    to energy; confirm this also holds for the daily granularity.
    """
    return round(p * 0.25, 2) if type(p) in [int, float] else ""


def _align_to_slots(slots, records, value_of, power_of):
    """Project per-slot records onto the full slot axis.

    :param slots: ordered slot labels (the x axis).
    :param records: mapping of slot label -> raw record.
    :param value_of: callable extracting the raw load value from a record.
    :param power_of: callable extracting the energy value from a record.
    :return: ``(values, powers)``, both ``len(slots)`` long. Values go
        through ``_format_load``. Slots with no record get ``""`` in both
        lists.
    """
    values, powers = [], []
    for slot in slots:
        if slot not in records:
            values.append("")
            powers.append("")
            continue
        record = records[slot]
        value = value_of(record)
        powers.append(power_of(record))
        values.append(_format_load(value))
    return values, powers


def _relative_deviations(real_list, pred_list):
    """Compute ``|(pred - real) / real|`` for every slot that has both values.

    Slots where either side is missing (``""``) or zero are skipped. This
    also prevents division by zero.

    :return: ``(deviations, slot_indexes)``. ``deviations`` is the compacted
        list (rounded to 4 decimals). ``slot_indexes`` runs in parallel and
        gives the position on the slot axis that each deviation belongs to.
    """
    deviations, slot_indexes = [], []
    for index, (real, pred) in enumerate(zip(real_list, pred_list)):
        if not real or not pred:
            continue
        deviations.append(abs(round((pred - real) / real, 4)))
        slot_indexes.append(index)
    return deviations, slot_indexes


def _total_deviation(real_list, pred_list):
    """Compute ``|(sum(pred) - sum(real)) / sum(real)|`` over paired slots.

    Only slots where both real and predicted values are present are summed,
    so both totals cover the same period. Returns ``""`` when the real total
    is 0, because the ratio is undefined then.
    """
    pairs = [(real, pred) for real, pred in zip(real_list, pred_list)
             if real != "" and pred != ""]
    real_total = sum(real for real, _ in pairs)
    if not real_total:
        return ""
    pred_total = sum(pred for _, pred in pairs)
    return round(abs((pred_total - real_total) / real_total), 4)


def _power_deviations(real_power_list, pred_power_list):
    """Compute per-slot ``|(pred - real) / real|`` of energy.

    A slot gets ``""`` when either side is missing or the real value is 0.
    """
    return [
        "" if real == "" or pred == "" or real == 0
        else abs(round((pred - real) / real, 4))
        for real, pred in zip(real_power_list, pred_power_list)
    ]


def _blank_today_power_deviation(start, slots, is_day, deviation_power_list):
    """Blank today's energy deviation for a by-day view of the current month.

    Today's energy deviation is treated as meaningless (presumably because
    today's data is still incomplete). Modifies *deviation_power_list* in
    place. Does nothing if today's slot is not in the queried range.
    """
    if not is_day:
        return
    today = str(datetime.now().date())  # "YYYY-MM-DD"
    # Compare year-month of the query start with the current year-month.
    if start.split(" ")[0].rsplit("-", 1)[0] != today.rsplit("-", 1)[0]:
        return
    today_slot = today.split("-", 1)[1]  # "MM-DD", the by-day slot label
    if today_slot in slots:
        deviation_power_list[slots.index(today_slot)] = ""


def _summarize_forecast(start, slots, is_day, real_list, real_power_list,
                        pred_list, pred_power_list):
    """Build the ``ForecastResp`` from slot-aligned real and predicted series.

    :param start: query start string. Its date part turns slot labels into
        full times for the max/min deviation (year prefix for by-day,
        date prefix for 15-minute slots).
    :param slots: slot labels shared by all the series.
    :param is_day: True for by-day granularity, False for 15-minute.
    :return: ``ForecastResp``. ``deviation_list`` holds only the absolute
        deviations of slots with both values, so it is not slot-aligned.
    """
    lv_real = LoadValue()
    lv_real.slots = slots
    lv_real.value = real_list
    lv_pred = LoadValue()
    lv_pred.slots = slots
    lv_pred.value = pred_list

    deviations, slot_indexes = _relative_deviations(real_list, pred_list)
    log.info(f"deviation_list_tmp:{deviations}, "
             f"deviation_list_abs:{deviations}")
    if not deviations:
        return ForecastResp(
            pred_data=lv_pred,
            real_data=lv_real,
            deviation_list=deviations,
            max_deviation=[],
            min_deviation=[],
            avg_deviation="",
            total_deviation="",
            real_power_list=LoadValue(slots=slots, value=real_power_list),
            pred_power_list=LoadValue(slots=slots, value=pred_power_list),
            deviation_power_list=[],
        )

    _, maxnum, minnum, average, max_index, min_index = choose_list(
        deviations, 4)
    # `deviations` skips slots without data, so positions in it are not slot
    # positions; map them back through slot_indexes.
    # NOTE(review): assumes choose_list's indexes refer to the list passed in.
    if is_day:
        prefix = start[:4] + "-"
    else:
        prefix = start.split(" ")[0] + " "
    max_time = prefix + slots[slot_indexes[max_index]]
    min_time = prefix + slots[slot_indexes[min_index]]

    deviation_power_list = _power_deviations(real_power_list, pred_power_list)
    _blank_today_power_deviation(start, slots, is_day, deviation_power_list)

    return ForecastResp(
        pred_data=lv_pred,
        real_data=lv_real,
        deviation_list=deviations,
        max_deviation=[maxnum, max_time],
        min_deviation=[minnum, min_time],
        avg_deviation=average,
        total_deviation=_total_deviation(real_list, pred_list),
        real_power_list=LoadValue(slots=slots, value=real_power_list),
        pred_power_list=LoadValue(slots=slots, value=pred_power_list),
        deviation_power_list=deviation_power_list,
    )


async def load_forecast_service(cid, cids, start, end):
    """Load forecast: compare real and day-ahead predicted load (Elasticsearch).

    :param cid: company id, used when *cids* is empty.
    :param cids: optional list of company ids; takes precedence over *cid*.
    :param start: range start string (date part before the first space).
    :param end: range end string.
    :return: ``ForecastResp``. An empty one is returned when either the real
        or the predicted query returns no buckets.
    """
    terms = {"cid": cids if cids else [cid]}
    query = EsQueryBody(terms=terms, start=start, end=end,
                        date_key="quarter_time", size=500)
    real_query = query.query()
    # The same aggregation (per day or per 15 minutes) is used for both
    # the real and the predicted index.
    is_day = interval_type("range", start, end) == "day"
    real_query["aggs"] = {
        "quarter_time": {
            "date_histogram": {
                "field": "quarter_time",
                "interval": "day" if is_day else "15m",
                "time_zone": "+08:00",
                "format": "yyyy-MM-dd HH:mm:ss",
            },
            "aggs": {"p": {"stats": {"field": "p"}},
                     "kwh": {"stats": {"field": "kwh"}}},
        }
    }

    # Real data.
    real_index = COMPANY_15MIN_POWER
    log.info(real_index + f"====={real_query}")
    async with EsUtil() as es:
        real_re = await es.search_origin(body=real_query, index=real_index)
    real_buckets = real_re["aggregations"]["quarter_time"]["buckets"]
    if not real_buckets:
        return ForecastResp()
    _, slots = time_pick_transf(start, end)
    fmt = "MM-DD" if is_day else "HH:mm"
    real_list, real_power_list = _align_to_slots(
        slots,
        es_process(real_buckets, fmat=fmt),
        value_of=lambda rec: rec.get("p").get("sum"),
        power_of=lambda rec: _number_or_blank(rec.get("kwh").get("sum")),
    )

    # Predicted data.
    pred_index = COMPANY_DAY_AHEAD_PRE
    log.info(pred_index + f"====={real_query}")
    async with EsUtil() as es:
        pred_re = await es.search_origin(body=real_query, index=pred_index)
    pred_buckets = pred_re["aggregations"]["quarter_time"]["buckets"]
    if not pred_buckets:
        return ForecastResp()
    pred_list, pred_power_list = _align_to_slots(
        slots,
        es_process(pred_buckets, fmat=fmt),
        value_of=lambda rec: rec.get("p").get("sum"),
        power_of=lambda rec: _quarter_hour_energy(rec.get("p").get("sum")),
    )

    return _summarize_forecast(start, slots, is_day, real_list,
                               real_power_list, pred_list, pred_power_list)


async def load_forecast_service_new15(cid, cids, start, end):
    """Load forecast backed by SQL tables instead of Elasticsearch.

    Reads ``company_1day_power`` for by-day ranges and
    ``company_15min_power`` otherwise. Predictions come from
    ``get_pred_p_dao``.

    :param cid: company id, used when *cids* is empty.
    :param cids: optional list of company ids; takes precedence over *cid*.
    :param start: range start string (date part before the first space).
    :param end: range end string.
    :return: ``ForecastResp``. An empty one is returned when either the real
        or the predicted query returns no rows.
    """
    terms = cids if cids else [cid]
    is_day = interval_type("range", start, end) == "day"
    if is_day:
        table_name, fmt = "company_1day_power", "%m-%d"
    else:
        table_name, fmt = "company_15min_power", "%H:%M"

    # Real data.
    datas = await get_kwh_p_dao(table_name, terms, start, end)
    if not datas:
        return ForecastResp()
    _, slots = time_pick_transf(start, end)
    real_list, real_power_list = _align_to_slots(
        slots,
        sql_time_process(datas, "create_time", fmt),
        value_of=lambda rec: rec.get("p"),
        power_of=lambda rec: _number_or_blank(rec.get("kwh")),
    )

    # Predicted data.
    pred_re = await get_pred_p_dao(terms, start, end)
    if not pred_re:
        return ForecastResp()
    pred_list, pred_power_list = _align_to_slots(
        slots,
        sql_time_process(pred_re, "create_time", fmt),
        value_of=lambda rec: rec.get("p"),
        power_of=lambda rec: _quarter_hour_energy(rec.get("p")),
    )

    return _summarize_forecast(start, slots, is_day, real_list,
                               real_power_list, pred_list, pred_power_list)


