# -*- coding:utf-8 -*-
#
# Author:jing
# Date: 2020/7/9
import json
from dataclasses import fields
from datetime import datetime

from pot_libs.mysql_util.mysql_util import MysqlUtil
from pot_libs.sanic_api import summary, skip_validate, description, examples
from pot_libs.utils.exc_util import DBException
from unify_api import constants
from unify_api.modules.electric.procedures.electric_util import get_wiring_type
from unify_api.utils import time_format
from pot_libs.es_util.es_utils import EsUtil
from pot_libs.es_util.es_query import EsQuery
from pot_libs.es_util import es_helper
from pot_libs.logger import log
from pot_libs.common.components.query import (
    PageRequest,
    Equal,
    Filter,
    Range
)
from unify_api.modules.alarm_manager.components.list_alarm import (
    QueryDetails,
    ScopeDetailsResponse,
    TempTrendResponse,
    ScopeContent,
    temp_trend_example,
    scope_details_example
)


def _scope_wave_window(location):
    """Return the slice that picks the samples around ``location``.

    The window is nominally 200 samples either side of ``location``. When
    ``location <= 200`` it is pinned to the first 400 samples, and when
    ``location >= 1400`` it runs from sample 1200 to the end of the series.
    """
    if location <= 200:
        return slice(None, 400)
    if location >= 1400:
        return slice(1200, None)
    return slice(location - 200, location + 200)


@summary("波形分析")
@description("后端返回全部波形数据,客户端筛选展示")
@examples(scope_details_example)
async def get_scope_details(req, query: QueryDetails) -> ScopeDetailsResponse:
    """Return the waveform record and fault metadata for one alarm event.

    The event document is looked up by ``query.doc_id`` in the
    ``POINT_1MIN_EVENT`` index; its ``point_id`` and ``time`` are then used to
    find the matching document in ``POINT_1MIN_SCOPE``. When
    ``query.wave_range`` is ``"100ms"`` every waveform series is trimmed to a
    window around the fault location, otherwise the full series is returned.

    Args:
        req: The incoming request (not used by this handler).
        query: Carries ``doc_id`` (ES ``_id`` of the event) and ``wave_range``.

    Returns:
        A populated ``ScopeDetailsResponse``, or ``db_error()`` when an ES
        query fails or the scope document cannot be found.

    Raises:
        DBException: The event document is missing from ``POINT_1MIN_EVENT``.
    """
    # 1. Look up the event document by its ES id to get the point attributes.
    id_equal = Equal(field="_id", value=query.doc_id)
    event_filter = Filter(equals=[id_equal], ranges=[], in_groups=[],
                          keywords=[])
    event_request = PageRequest(page_size=1, page_num=1, filter=event_filter,
                                sort=None)
    event_body = EsQuery.query(event_request)
    try:
        async with EsUtil() as es:
            event_results = await es.search_origin(
                body=event_body,
                index=constants.POINT_1MIN_EVENT)
    except Exception as e:
        # ``Exception`` (not a bare except) so task cancellation propagates.
        log.error(f"es query error {e}")
        return ScopeDetailsResponse().db_error()

    try:
        event_hit = event_results["hits"]["hits"][0]
        source = event_hit.get("_source", {})
        point_name = source.get("name")
        point_id = source.get("point_id")
        cmp_time = source.get("time")
    except (KeyError, IndexError, TypeError, AttributeError) as err:
        # Report the index that was actually queried in this step.
        log.warning("can not find data on es(index: %s): %s" % (
            constants.POINT_1MIN_EVENT, event_body))
        raise DBException from err

    # 2. Use the event's point_id and time to find the scope document.
    point_equal = Equal(field="point_id", value=point_id)
    time_equal = Equal(field="time", value=cmp_time)
    scope_filter = Filter(equals=[point_equal, time_equal], ranges=[],
                          in_groups=[], keywords=[])
    scope_request = PageRequest(page_size=1, page_num=1, filter=scope_filter,
                                sort=None)
    scope_body = EsQuery.query(scope_request)
    try:
        async with EsUtil() as es:
            scope_results = await es.search_origin(
                body=scope_body,
                index=constants.POINT_1MIN_SCOPE)
    except Exception as e:
        log.error(f"es query error {e}")
        return ScopeDetailsResponse().db_error()

    # A successful search with zero hits is also "not found"; guard it here
    # instead of failing with IndexError on hits[0].
    scope_hits = (scope_results or {}).get("hits", {}).get("hits", [])
    scope_source = scope_hits[0].get("_source") if scope_hits else None
    if not scope_source:
        log.warning("can not find data on es(index: %s): %s" % (
            constants.POINT_1MIN_SCOPE, scope_body))
        return ScopeDetailsResponse().db_error()

    fault_type = constants.EVENT_TYPE_MAP.get(scope_source.get("fault_type"))
    date_time = time_format.esstr_to_dtstr(scope_source.get("datetime"))

    # "result" is a JSON object; parse it once and reuse it below.
    res_dic = json.loads(scope_source.get("result"))

    # One-character keys get the "相" suffix, two-character keys get "线";
    # keys of any other length are left out of the label.
    position_tmp = []
    for name in res_dic:
        if len(name) == 1:
            position_tmp.append(name + "相")
        elif len(name) == 2:
            position_tmp.append(name + "线")
    position_str = "|".join(position_tmp)

    # 3. Waveform data. The location is read from the first entry of "result".
    first_result = next(iter(res_dic.values()), None)
    location = first_result.get("location") if first_result else None
    wave_data = json.loads(scope_source.get("context"))
    # Without a location there is nothing to centre the window on, so the
    # series are returned untrimmed in that case.
    if query.wave_range == "100ms" and location is not None:
        window = _scope_wave_window(location)
        wave_data = {key: value[window] for key, value in wave_data.items()}

    # Only pass the keys that ScopeContent actually declares.
    content_fields = {field.name for field in fields(ScopeContent)}
    scope_content = ScopeContent(
        **{k: v for k, v in wave_data.items() if k in content_fields}
    )

    # 4. Wiring type (two-meter / three-meter); falls back to 3 when unknown.
    ctnum, _ = await get_wiring_type(point_id)
    if not ctnum:
        ctnum = 3
    return ScopeDetailsResponse(
        ctnum=ctnum,
        group=point_name,
        item=position_str,
        type=fault_type,
        date_time=date_time,
        location=location,
        contents=scope_content
    )


def _round_or_blank(value):
    """Round ``value`` to two decimals, or return ``""`` when it is None."""
    return round(value, 2) if value is not None else ""


@skip_validate
@summary("获取温度趋势")
@examples(temp_trend_example)
async def get_temp_trend(req, query: QueryDetails) -> TempTrendResponse:
    """Return measured and forecast temperature series for an alarm's day.

    The event document is looked up by ``query.doc_id`` in
    ``POINT_1MIN_EVENT``. Its ``location_id`` and ``datetime`` then drive two
    further queries: an aggregation over ``LOCATION_1MIN_AIAO`` for measured
    values and a search over ``LOCATION_TEMP_TREND`` for forecast values. All
    three series are aligned to the time slots derived from the event's date;
    slots with no data hold ``""``.

    Args:
        req: The incoming request (not used by this handler).
        query: Carries ``doc_id``, the ES ``_id`` of the event.

    Returns:
        A ``TempTrendResponse``. ``db_error()`` is returned when an ES query
        fails or the event cannot be found; a response with empty series is
        returned when a data query yields no result.
    """
    # 1. Fetch the event document by ES id to obtain its location_id.
    id_equal = Equal(field="_id", value=query.doc_id)
    event_filter = Filter(equals=[id_equal], ranges=[],
                          in_groups=[], keywords=[])
    event_request = PageRequest(page_size=1, page_num=1, filter=event_filter,
                                sort=None)
    event_body = EsQuery.query(event_request)
    try:
        async with EsUtil() as es:
            event_results = await es.search_origin(
                body=event_body, index=constants.POINT_1MIN_EVENT)
    except Exception as e:
        # ``Exception`` (not a bare except) so task cancellation propagates.
        log.error(f"es error {e}")
        return TempTrendResponse().db_error()

    # An empty response, an empty hit list and a hit without _source are all
    # treated as "event not found".
    try:
        source = event_results["hits"]["hits"][0].get("_source")
    except (KeyError, IndexError, TypeError, AttributeError):
        source = None
    if not source:
        log.warning("can not find data on es(index: %s): %s" % (
            constants.POINT_1MIN_EVENT, event_body))
        return TempTrendResponse().db_error()

    # Header fields of the temperature-trend page.
    group_name = source.get("name")
    date_time = source.get("datetime")
    message = source.get("message")
    location_id = source.get("location_id")
    item = ""
    if location_id:
        async with MysqlUtil() as conn:
            location_sql = "SELECT item FROM location WHERE id = %s"
            location = await conn.fetchone(location_sql, args=(location_id,))
        # fetchone may find no row for this id; keep the empty default then.
        if location:
            item = location["item"] or ""

    # 2. Build the time axis from the start/end of the event's day.
    start_str, end_str = time_format.get_start_end_by_tz_time(date_time)
    interval, slots = time_format.time_pick_transf(start_str, end_str)

    # 3. Measured temperature, aggregated per interval over that day.
    location_equal = Equal(field="location_id", value=location_id)
    start_ts = time_format.get_date_timestamp(start_str)
    end_ts = time_format.get_date_timestamp(end_str)
    time_range = Range(field="time", start=start_ts, end=end_ts)
    real_filter = Filter(equals=[location_equal], ranges=[time_range],
                         in_groups=[], keywords=[])
    real_request = PageRequest(page_size=1, page_num=1, sort=None,
                               filter=real_filter)
    query_body_realdata = EsQuery.aggr_history(real_request,
                                               interval=interval,
                                               stats_items=["value"])
    try:
        async with EsUtil() as es:
            real_results = await es.search_origin(
                body=query_body_realdata,
                index=constants.LOCATION_1MIN_AIAO)
    except Exception as e:
        log.error(f"es error {e}")
        return TempTrendResponse().db_error()

    # Reformat the event time for display; the input format is fixed to a
    # "+08:00" offset.
    date_time = datetime.strptime(date_time, "%Y-%m-%dT%H:%M:%S+08:00")
    date_time = datetime.strftime(date_time, "%Y-%m-%d %H:%M:%S")

    # Fields shared by every non-error response from here on.
    header = dict(group=group_name,
                  item=item,
                  date_time=date_time,
                  description=message,
                  time_slots=slots)

    if not real_results:
        log.warning("can not find data on es(index: %s): %s" % (
            constants.LOCATION_1MIN_AIAO, query_body_realdata))
        return TempTrendResponse(realtime=[], daily=[], quarterly=[],
                                 **header)

    aggs_res = real_results.get('aggregations', {})
    real_data_buckets = aggs_res.get('aggs_name', {}).get("buckets", [])
    real_data_tmp = es_helper.process_es_aggs_aiao(real_data_buckets,
                                                   time_key="key",
                                                   value_key="value")

    # 4. Forecast data (day-ahead and real-time) for the same day.
    start_dt = time_format.convert_to_dt(start_str)
    end_dt = time_format.convert_to_dt(end_str)
    start_time = time_format.convert_dt_to_str(start_dt, date_type="tz")
    end_time = time_format.convert_dt_to_str(end_dt, date_type="tz")
    location_equal = Equal(field="location_id", value=location_id)
    quarter_range = Range(field="quarter", start=start_time, end=end_time)
    trend_filter = Filter(equals=[location_equal], ranges=[quarter_range],
                          in_groups=[], keywords=[])
    trend_request = PageRequest(page_size=100, page_num=1,
                                filter=trend_filter, sort=None)
    query_body_trend = EsQuery.query(trend_request)
    try:
        async with EsUtil() as es:
            trend_results = await es.search_origin(
                body=query_body_trend,
                index=constants.LOCATION_TEMP_TREND)
    except Exception as e:
        log.error(f"es error {e}")
        return TempTrendResponse().db_error()

    if not trend_results:
        # Log the trend query body, i.e. the query that produced no data.
        log.warning("can not find data on es(index: %s): %s" % (
            constants.LOCATION_TEMP_TREND, query_body_trend))
        return TempTrendResponse(realtime=[], daily=[], quarterly=[],
                                 **header)

    trend_data = trend_results['hits']['hits']
    trend_data_tmp = es_helper.process_es_data_aiao(trend_data, key="quarter")

    realtime = []  # measured temperature
    daily = []  # day-ahead forecast (overTempTrendDaily)
    quarterly = []  # real-time forecast (overTempTrendQuarterly)
    for slot in slots:
        # Measured data: the "avg" value of the slot's bucket.
        if slot in real_data_tmp:
            realtime.append(_round_or_blank(real_data_tmp[slot].get("avg")))
        else:
            realtime.append('')
        # Forecast data.
        if slot in trend_data_tmp:
            trend_source = trend_data_tmp[slot]["_source"]
            daily.append(
                _round_or_blank(trend_source.get("overTempTrendDaily")))
            quarterly.append(
                _round_or_blank(trend_source.get("overTempTrendQuarterly")))
        else:
            daily.append('')
            quarterly.append('')

    return TempTrendResponse(realtime=realtime, daily=daily,
                             quarterly=quarterly, **header)