from pot_libs.utils.exc_util import BusinessException
from unify_api.constants import SLOTS
from unify_api.modules.elec_charge.dao.elec_charge_dao import \
    histogram_aggs_points
from unify_api.modules.tsp_water.dao.drop_dust_dao import sum_water_group
from unify_api.modules.tsp_water.dao.tsp_dao import tsp_histogram_avg
from unify_api.utils import time_format
from unify_api.utils.common_utils import round_2n
from unify_api.utils.es_query_body import es_process
from pot_libs.mysql_util.mysql_util import MysqlUtil


async def per_hour_wave(start, end, tsp_id=None):
    """PM2.5/PM10/TSP hourly or daily curve data, read from ES aggregations.

    When ``time_format.time_pick_transf`` resolves the range to a 1-day
    step, buckets are per day and keyed "MM-DD". When it resolves to a
    15-minute step, the slots are replaced by ``SLOTS["day"]`` and buckets
    are per hour, keyed "HH:mm".

    :param start: range start, forwarded to ``tsp_histogram_avg``
    :param end: range end, forwarded to ``tsp_histogram_avg``
    :param tsp_id: optional TSP id, forwarded to ``tsp_histogram_avg``
    :return: ``(pm25_list, pm10_list, tsp_list, slots)``; each list holds
        one entry per slot, ``None`` where ES returned no bucket for it
    :raises BusinessException: if the resolved step is neither one day nor
        15 minutes
    """
    step, slots = time_format.time_pick_transf(start, end)
    if step == 15 * 60:
        # The requirement is one point per hour, so the 15-minute slots are
        # replaced by the "day" slot set and ES buckets per hour.
        slots = SLOTS["day"]
        bucket_interval, key_fmt = "hour", "HH:mm"
    elif step == 24 * 3600:
        bucket_interval, key_fmt = "day", "MM-DD"
    else:
        raise BusinessException(message="time range not day or month")

    # 1. Query ES and index the buckets by their formatted time key.
    es_res = await tsp_histogram_avg(start, end, bucket_interval, tsp_id)
    es_dic = es_process(es_res, fmat=key_fmt)

    # 2. Build one slot-aligned series per metric; missing slots are None.
    metrics = ("pm25", "pm10", "tsp")
    series = [[] for _ in metrics]
    for slot in slots:
        has_bucket = slot in es_dic
        for values, metric in zip(series, metrics):
            values.append(
                round_2n(es_dic[slot][metric].get("value"))
                if has_bucket else None
            )
    pm25_list, pm10_list, tsp_list = series
    return pm25_list, pm10_list, tsp_list, slots


async def per_hour_wave_new15(start, end, tsp_id=None):
    """PM2.5/PM10/TSP hourly or daily curve data, read from MySQL.

    For a 1-day step (as resolved by ``time_format.time_pick_transf``) the
    averages come from ``tsp_day_record`` grouped by "MM-DD". For a
    15-minute step they come from ``tsp_15min_record`` grouped by "HH:00".

    :param start: range start, bound as a query parameter
    :param end: range end, bound as a query parameter
    :param tsp_id: optional TSP id; when truthy, rows are filtered to it
    :return: ``(pm25_list, pm10_list, tsp_list, slots)``; each list holds
        one entry per slot, ``None`` where no row matched that slot
    :raises BusinessException: if the resolved step is neither one day nor
        15 minutes
    """
    interval, slots = time_format.time_pick_transf(start, end)
    if interval == 24 * 3600:
        table, date_fmt = "tsp_day_record", "%%m-%%d"
    elif interval == 15 * 60:
        table, date_fmt = "tsp_15min_record", "%%H:00"
    else:
        raise BusinessException(message="time range not day or month")

    # All caller-supplied values are bound as driver parameters instead of
    # being interpolated into the SQL text, so they are escaped. Because
    # args are passed, the driver %-formats the statement. That is why
    # DATE_FORMAT uses "%%" (it becomes a single "%"), matching
    # per_hour_kwh_wave_new15.
    if tsp_id:
        tsp_filter = "tsp_id=%s and "
        args = (tsp_id, start, end)
    else:
        tsp_filter = ""
        args = (start, end)
    sql = (
        f'SELECT DATE_FORMAT(create_time,"{date_fmt}") date_time, '
        f'AVG(pm25_mean) pm25,AVG(pm10_mean) pm10,AVG(tsp_mean) tsp '
        f'FROM `{table}` where {tsp_filter}'
        f'create_time BETWEEN %s and %s GROUP BY date_time '
        f'ORDER BY date_time'
    )
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql, args=args)
    datas_map = {data["date_time"]: data for data in datas}

    # NOTE(review): per_hour_wave overrides slots with SLOTS["day"] for the
    # 15-minute case ("one point per hour"), but this function keeps the
    # slots from time_pick_transf. Confirm they line up with the "HH:00"
    # keys produced above.
    pm25_list, pm10_list, tsp_list = [], [], []
    for slot in slots:
        slot_data = datas_map.get(slot)
        if slot_data:
            pm25_value = round_2n(slot_data.get("pm25"))
            pm10_value = round_2n(slot_data.get("pm10"))
            tsp_value = round_2n(slot_data.get("tsp"))
        else:
            pm25_value, pm10_value, tsp_value = None, None, None
        pm25_list.append(pm25_value)
        pm10_list.append(pm10_value)
        tsp_list.append(tsp_value)
    return pm25_list, pm10_list, tsp_list, slots


async def per_hour_kwh_wave(start, end, tsp_id_list=None):
    """Hourly or daily electricity (kWh) curve data, read from ES.

    When ``time_format.time_pick_transf`` resolves the range to a 1-day
    step, buckets are per day and keyed "MM-DD". When it resolves to a
    15-minute step, the slots are replaced by ``SLOTS["day"]`` and buckets
    are per hour, keyed "HH:mm".

    :param start: range start, forwarded to ``histogram_aggs_points``
    :param end: range end, forwarded to ``histogram_aggs_points``
    :param tsp_id_list: forwarded unchanged to ``histogram_aggs_points``
    :return: ``(kwh_list, slots)``; ``kwh_list`` holds one entry per slot,
        ``None`` where ES returned no bucket for it
    :raises BusinessException: if the resolved step is neither one day nor
        15 minutes
    """
    step, slots = time_format.time_pick_transf(start, end)
    if step == 15 * 60:
        # The requirement is one point per hour, so the 15-minute slots are
        # replaced by the "day" slot set and ES buckets per hour.
        slots = SLOTS["day"]
        bucket_interval, key_fmt = "hour", "HH:mm"
    elif step == 24 * 3600:
        bucket_interval, key_fmt = "day", "MM-DD"
    else:
        raise BusinessException(message="time range not day or month")

    # 1. Query ES and index the buckets by their formatted time key.
    es_res = await histogram_aggs_points(start, end, tsp_id_list,
                                         bucket_interval)
    es_dic = es_process(es_res, fmat=key_fmt)

    # 2. One value per slot, None where there is no bucket.
    kwh_list = [
        round_2n(es_dic[slot]["kwh"].get("value")) if slot in es_dic else None
        for slot in slots
    ]
    return kwh_list, slots


async def per_hour_kwh_wave_new15(start, end, pids):
    """Hourly or daily electricity (kWh) curve data, read from MySQL.

    For a 1-day step (as resolved by ``time_format.time_pick_transf``) the
    kWh sums come from ``point_1day_power`` grouped by "MM-DD". For a
    15-minute step they come from ``point_15min_power`` grouped by "HH:00".

    :param start: range start, bound as a query parameter
    :param end: range end, bound as a query parameter
    :param pids: point ids used in the ``pid in (...)`` filter
    :return: ``(kwh_list, slots)``; ``kwh_list`` holds one entry per slot,
        ``None`` where no row matched that slot (all ``None`` when ``pids``
        is empty)
    :raises BusinessException: if the resolved step is neither one day nor
        15 minutes
    """
    interval, slots = time_format.time_pick_transf(start, end)
    if interval == 24 * 3600:
        table, date_fmt = "point_1day_power", "%%m-%%d"
    elif interval == 15 * 60:
        table, date_fmt = "point_15min_power", "%%H:00"
    else:
        raise BusinessException(message="time range not day or month")

    # "pid in ()" is a MySQL syntax error. With no points there is simply
    # no data, so return an all-None curve instead of querying.
    if not pids:
        return [None] * len(slots), slots

    # All caller-supplied values are bound as driver parameters. Since args
    # are passed, the driver %-formats the statement, hence "%%" in
    # DATE_FORMAT.
    sql = (
        f'SELECT DATE_FORMAT(create_time,"{date_fmt}") date_time, '
        f'sum(kwh) kwh '
        f'FROM `{table}` where pid in %s and '
        f'create_time BETWEEN %s and %s GROUP BY date_time '
        f'ORDER BY date_time'
    )
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql, args=(pids, start, end))
    datas_map = {data["date_time"]: data for data in datas}

    # Assemble one value per slot; missing slots are None.
    kwh_list = []
    for slot in slots:
        slot_data = datas_map.get(slot)
        kwh_value = round_2n(slot_data.get("kwh")) if slot_data else None
        kwh_list.append(kwh_value)
    return kwh_list, slots


async def per_hour_water_wave(start, end, tsp_id_list=None):
    """Hourly or daily water-volume curve data, read from MySQL.

    A 1-day step (as resolved by ``time_format.time_pick_transf``) queries
    ``sum_water_group`` with ``date_type="month"``. A 15-minute step
    replaces the slots with ``SLOTS["day"]`` and queries with
    ``date_type="day"``.

    :param start: range start, forwarded to ``sum_water_group``
    :param end: range end, forwarded to ``sum_water_group``
    :param tsp_id_list: accepted for interface compatibility; not used here
    :return: a list with one entry per slot, ``None`` past the returned rows
        (unlike the sibling functions, ``slots`` is not returned)
    :raises BusinessException: if the resolved step is neither one day nor
        15 minutes
    """
    step, slots = time_format.time_pick_transf(start, end)
    if step == 15 * 60:
        # The requirement is one point per hour.
        slots = SLOTS["day"]
        date_type = "day"
    elif step == 24 * 3600:
        date_type = "month"
    else:
        raise BusinessException(message="time range not day or month")

    # 1. Query MySQL.
    water_info = await sum_water_group(start, end, date_type=date_type)

    # 2. Fill the slots.
    # NOTE(review): rows are placed by position, not matched to slot keys.
    # This assumes sum_water_group returns at most one row per slot, in slot
    # order and without gaps. Confirm, otherwise values land in the wrong
    # slot (or IndexError if there are more rows than slots).
    water_list = [None] * len(slots)
    for pos, row in enumerate(water_info):
        water_list[pos] = round_2n(row["water"])
    return water_list
