import pandas as pd
from collections import defaultdict
from datetime import datetime

from pot_libs.es_util.es_utils import EsUtil
from pot_libs.logger import log
from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api import constants


def _es_time_to_str(dt_str):
    """Render an ES ``%Y-%m-%dT%H:%M:%S+08:00`` string as ``str(datetime)``.

    The ``+08:00`` suffix is matched literally and dropped, giving a naive
    ``YYYY-MM-DD HH:MM:SS`` string; any other suffix raises ``ValueError``.
    """
    return str(datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S+08:00"))


def _top_hit_stat(stats_aggs, name):
    """Return ``{"value", "time"}`` from the ``name`` top_hits agg, or None.

    The single hit's ``_source`` is expected to carry both ``name`` and
    ``f"{name}_time"`` (e.g. ``value_max`` and ``value_max_time``).
    """
    hits = stats_aggs.get(name, {}).get("hits", {}).get("hits", [])
    if not hits:
        return None
    source = hits[0]["_source"]
    return {
        "value": source[name],
        "time": _es_time_to_str(source[f"{name}_time"]),
    }


def _collect_stats(stats_aggs):
    """Build a ``{value_min?, value_max?, value_avg}`` dict from one agg set.

    ``value_min`` / ``value_max`` are present only when ES returned a hit;
    ``value_avg`` is always present (``None`` when ES had no value).
    """
    stats = {}
    for name in ("value_min", "value_max"):
        hit = _top_hit_stat(stats_aggs, name)
        if hit is not None:
            stats[name] = hit
    stats["value_avg"] = stats_aggs.get("value_avg", {}).get("value")
    return stats


async def location_stats_statics(
        cid, pid, start_timestamp, end_timestamp, _type="residual_current"
):
    """Aggregate min/max/avg 15-minute stats for the locations of one point.

    Selects the ``location`` rows of ``cid`` whose ``type`` matches, keeps
    those sharing the ``mtid`` of point ``pid``, and queries the ES index
    ``constants.LOCATION_15MIN_AIAO`` for their documents with ``time`` in
    ``[start_timestamp, end_timestamp]``.

    :param cid: value matched against ``location.cid`` and the ES ``cid`` term.
    :param pid: id of the ``point`` row whose ``mtid`` selects the locations.
    :param start_timestamp: inclusive lower bound on the ES ``time`` field.
    :param end_timestamp: inclusive upper bound on the ES ``time`` field.
    :param _type: ``"residual_current"`` or ``"temperature"``; any other
        value looks up ``temperature`` locations but returns ``None``.
    :return: ``{}`` if no location matches. Otherwise a ``defaultdict(dict)``
        keyed by ``"漏电流"`` (leakage/residual current) or, for temperature,
        ``f"{item}温度"`` per location item. Each value holds optional
        ``value_min`` / ``value_max`` entries (``{"value", "time"}``) plus
        ``value_avg``.
    """
    location_type = (
        "residual_current" if _type == "residual_current" else "temperature"
    )
    async with MysqlUtil() as conn:
        sql = "SELECT mtid, id, `group`, item, `type` FROM location WHERE cid=%s and `type`=%s"
        locations = await conn.fetchall(sql, args=(cid, location_type))
        sql = "SELECT pid, name, mtid FROM point WHERE pid=%s"
        point = await conn.fetchone(sql, args=(pid,))

    # item -> ids of this point's locations that measure that item
    location_map = {}
    if point:
        point_mtid = point["mtid"]
        for loc in locations or []:
            if loc.get("mtid") == point_mtid:
                location_map.setdefault(loc["item"], []).append(loc["id"])

    if not location_map:
        log.info(f"{pid}无任何location_id")
        return {}

    all_location_ids = [
        lid for item_ids in location_map.values() for lid in item_ids
    ]

    aggs = {
        "value_avg": {"avg": {"field": "value_avg"}},
        "value_max": {"top_hits": {"sort": [{"value_max": {"order": "desc"}}], "size": 1}},
        "value_min": {"top_hits": {"sort": [{"value_min": {"order": "asc"}}], "size": 1}},
    }
    query_body = {
        "query": {
            "bool": {
                "filter": [
                    {"term": {"cid": cid}},
                    {"term": {"type.keyword": _type}},
                    {"terms": {"location_id": all_location_ids}},
                    {"range": {"time": {"gte": start_timestamp, "lte": end_timestamp}}},
                ]
            }
        },
        "size": 0,
        "aggs": {},
    }

    if _type == "residual_current":
        query_body["aggs"] = aggs
    elif _type == "temperature":
        # One filter sub-aggregation per item so each item gets its own stats.
        query_body["aggs"] = {
            f"{item}_aggs": {
                "filter": {"terms": {"location_id": item_ids}},
                "aggs": aggs,
            }
            for item, item_ids in location_map.items()
        }
    else:
        # No aggregation is defined for other types: no ES query, no result.
        return None

    async with EsUtil() as es:
        es_results = await es.search_origin(
            body=query_body, index=constants.LOCATION_15MIN_AIAO
        )

    info_map = defaultdict(dict)
    if _type == "residual_current":
        # A missing/None response is treated like an empty one.
        aggregations = (es_results or {}).get("aggregations", {})
        info_map["漏电流"] = _collect_stats(aggregations)
    elif es_results:
        aggregations = es_results.get("aggregations", {})
        for item in location_map:
            # Default to {} so a missing sub-aggregation yields empty stats.
            info_map[f"{item}温度"] = _collect_stats(
                aggregations.get(f"{item}_aggs", {})
            )
    return info_map


def _extreme_with_time(rows, value_col, time_col, largest):
    """Return ``(value, time)`` for the row holding the extreme of ``value_col``.

    ``value`` is the max (``largest=True``) or min of ``value_col`` rounded to
    2 decimals, and ``time`` is ``str()`` of that row's ``time_col`` cell
    (``""`` when the cell is null or the column is absent). Returns
    ``("", "")`` when ``value_col`` has no non-NaN value in ``rows``.
    """
    column = rows[value_col]
    value = column.max() if largest else column.min()
    if pd.isna(value):
        return "", ""
    row = rows.loc[column.idxmax() if largest else column.idxmin()]
    value_time = row.get(time_col)
    return round(value, 2), ("" if pd.isnull(value_time) else str(value_time))


async def location_stats_statics_new15(table_name, cid, start, end):
    """Summarise per-location max/min/mean values from a MySQL stats table.

    Loads every ``location`` row of ``cid``. It then reads the rows of
    ``table_name`` for those ``lid`` values whose ``create_time`` is
    ``BETWEEN start AND end``. For each lid it adds:

    - ``max_value`` / ``max_value_time``: largest ``value_max`` and its
      ``value_max_time``;
    - ``min_value`` / ``min_value_time``: smallest ``value_min`` and its
      ``value_min_time``;
    - ``mean_value``: mean of ``value_avg``.

    Numbers are rounded to 2 decimals; a lid with no data gets ``""``.

    :param table_name: stats table name. It is formatted into the SQL text
        (identifiers cannot be bound), so it must come from trusted code.
    :param cid: value matched against ``location.cid``.
    :param start: lower ``create_time`` bound, bound as a query parameter.
    :param end: upper ``create_time`` bound, bound as a query parameter.
    :return: ``{lid: location_row}`` with the stats merged into each row, or
        ``{}`` when there are no locations or no rows in the window.
    """
    sql = "SELECT mtid,lid,item,ad_type FROM location WHERE cid=%s"
    async with MysqlUtil() as conn:
        locations = await conn.fetchall(sql, args=(cid, ))
    if not locations:
        return {}
    location_map = {loca["lid"]: loca for loca in locations}
    lids = list(location_map)

    # start/end are bound parameters rather than quoted into the SQL text,
    # which avoids injection and quoting problems.
    datas_sql = (
        f"SELECT * from {table_name} where lid in %s and create_time"
        " BETWEEN %s and %s order by create_time desc"
    )
    async with MysqlUtil() as conn:
        results = await conn.fetchall(datas_sql, args=(lids, start, end))
    if not results:
        return {}

    df = pd.DataFrame(list(results))
    for lid in lids:
        # Filter once per location instead of re-scanning df for every stat.
        rows = df.loc[df["lid"] == lid]
        max_value, max_value_time = _extreme_with_time(
            rows, "value_max", "value_max_time", largest=True
        )
        min_value, min_value_time = _extreme_with_time(
            rows, "value_min", "value_min_time", largest=False
        )
        mean_value = rows["value_avg"].mean()
        # Only NaN (no data) maps to ""; a genuine mean of 0 is kept, matching
        # how max/min treat 0.
        mean_value = "" if pd.isna(mean_value) else round(mean_value, 2)
        location_map[lid].update({
            "max_value": max_value, "max_value_time": max_value_time,
            "min_value": min_value, "min_value_time": min_value_time,
            "mean_value": mean_value,
        })
    return location_map
