# -*- coding: utf-8 -*-
"""
Calculate user-side energy storage optimization for companies.
This runs in a tornado process and responds to requests from the web back-end server.
"""

import pendulum
import datetime
import pandas as pd
from pot_libs.es_util.es_utils import EsUtil
from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api.constants import ENERGY_INVESTMENT_ESS, \
    INLINE_15MIN_POWER_ESINDEX, INLINE_1DAY_POWER_ESINDEX
from unify_api.modules.energy_optimize.service.ess_optimation_tool import \
    EssOptimizationTool
from unify_api.modules.energy_optimize.service.ess_utils import \
    PricePolicyHelper


class AutoDic(dict):
    """A ``dict`` that auto-creates nested ``AutoDic`` values for missing keys.

    Reading a missing key with ``d[key]`` stores and returns a new, empty
    instance of the same class, so ``d['a']['b']['c'] = 1`` works without
    pre-creating the intermediate levels. ``d.get(key)`` and ``key in d``
    do not create entries.
    """

    def __missing__(self, key):
        # dict.__getitem__ calls this hook only when ``key`` is absent, so
        # lookups of existing keys stay on the fast C path.
        value = self[key] = type(self)()
        return value


class EnergyStoreOptimize(object):
    """User-side energy-storage optimisation for one inline (``inlid``).

    ``calc_inline`` reads the inline's info and its company's price policy
    from MySQL, finds the highest-kWh day of the last six months and that
    day's 15-minute power curve in Elasticsearch, and feeds everything to
    ``EssOptimizationTool``.
    """

    # Price-period characters, in the order their sections are built. Each
    # one is the suffix of a ``price_<char>`` key in the price-policy info
    # and a value of the ``spfv`` field in the 15-minute ES index.
    # NOTE(review): presumably sharp/peak/flat/valley periods -- confirm.
    _PRICE_CHARS = ('s', 'p', 'f', 'v')
    # Column labels of ``invest_income_table`` copied into the API output.
    _INVEST_INCOME_COLUMNS = ('5', '6', '7', '8', '9', '10')

    def __init__(self, inlid):
        """:param inlid: id of the inline to optimise."""
        self._inlid = inlid

    async def calc_inline(self, ess_system):
        """Run the storage optimisation for this inline.

        :param ess_system: storage-system parameters, passed unchanged to
            ``EssOptimizationTool``.
        :return: a dict with ``rlt_flag``. When it is False, ``message``
            says why ('暂无' = no data; '无存储优化空间' = no room for
            storage optimisation). When it is True, the dict also carries
            ``opt_analysis``, ``capacity``, ``economic_evaluate`` and
            ``opt_curve``.
        """
        inl_info = await self._get_inline_info()
        max_dt = await self._find_kwh_max_day()  # 00:00:00 of the max kwh day
        if not max_dt:
            return {'rlt_flag': False, 'message': '暂无'}

        # Fetch the curve before the price policy and per-period aggregations
        # so an empty curve short-circuits without issuing those queries.
        df_curve = await self._get_kw_curve(max_dt)
        if len(df_curve) == 0:
            return {'rlt_flag': False, 'message': '暂无'}

        pp = await self._get_company_price_policy(inl_info['cid_belongedto'])
        pp_info_d = PricePolicyHelper.map_price_policy(
            pp, inl_info['inline_vc'], max_dt.int_timestamp)
        time_str_d = PricePolicyHelper.quarter_chars_2_time_str(
            pp_info_d['quarters'])

        inline_var = {'inline_capacity': inl_info['tc_runtime'],
                      'capacity_price': pp_info_d['price_tc'],
                      'max_demand_price': pp_info_d['price_md']}
        for p_char in self._PRICE_CHARS:
            # Only periods whose price is truthy get a ``section_<char>``.
            if pp_info_d['price_' + p_char]:
                section = await self._contruct_section(
                    p_char, pp_info_d, time_str_d, max_dt)
                inline_var['section_' + p_char] = section

        eot = EssOptimizationTool(inline_var, ess_system, df_curve)
        eot.output()
        if not eot.flag:
            return {'rlt_flag': False, 'message': '无存储优化空间'}
        return {
            'rlt_flag': True,
            'opt_analysis': eot.opt_analysis,
            'capacity': eot.capacity,
            'economic_evaluate': self.convert_economic_evaluate(
                eot.economic_evaluate),
            'opt_curve': self.convert_opt_curve(eot.opt_curve),
        }

    def convert_economic_evaluate(self, economic_evaluate):
        """Convert ``invest_income_table`` from a DataFrame to a list of dicts.

        Each row becomes ``{'rate_of_investment': <row index>, '5': ...,
        ..., '10': ...}``. ``economic_evaluate`` is modified in place and
        also returned.
        """
        table = []
        for idx, row in economic_evaluate['invest_income_table'].iterrows():
            tmp_d = {'rate_of_investment': idx}
            for col in self._INVEST_INCOME_COLUMNS:
                tmp_d[col] = row[col]
            table.append(tmp_d)
        economic_evaluate['invest_income_table'] = table
        return economic_evaluate

    def convert_opt_curve(self, opt_curve):
        """Convert the optimised-curve DataFrame to a list of dicts.

        Each row becomes ``{'quarter_time': <row index>, 'load_curve': ...,
        'bat_curve': ..., 'load_bat_curve': ...}``.
        """
        return [
            {'quarter_time': idx,
             'load_curve': row['load_curve'],
             'bat_curve': row['bat_curve'],
             'load_bat_curve': row['load_bat_curve']}
            for idx, row in opt_curve.iterrows()
        ]

    async def _contruct_section(self, p_char, pp_info_d, time_str_d, max_dt):
        """Build ``section_<p_char>`` for inline_var.

        :return: ``{'price': ..., 'ttl_kwh': <kwh summed over the day for
            this period>, 'time_range': <';'-joined time strings>}``
        """
        section = {'price': pp_info_d['price_' + p_char]}
        section['ttl_kwh'] = await self._get_total_kwh_of_one_pchar(
            p_char, max_dt)
        section['time_range'] = ';'.join(time_str_d[p_char])
        return section

    async def _get_inline_info(self):
        """Read inline_vc, tc_runtime and the owning cid from MySQL ``inline``.

        :return: a dict with keys ``inline_vc``, ``tc_runtime`` and
            ``cid_belongedto``.
        NOTE(review): if no row matches, ``fetchone`` presumably returns None
        and the subscripting below raises TypeError.
        """
        sql = "SELECT inline_vc, tc_runtime, cid cid_belongedto from " \
              "inline where inlid = %s"
        async with MysqlUtil() as conn:
            info = await conn.fetchone(sql, args=(self._inlid,))
        rlt = {'inline_vc': info['inline_vc'],
               'tc_runtime': info['tc_runtime'],
               'cid_belongedto': info['cid_belongedto']}
        return rlt

    async def _get_company_price_policy(self, cid):
        """Load all ``price_policy`` rows of company ``cid``.

        :return: an ``AutoDic`` nested as
            ``[str(inline_vc)][str(start_month)][time_range] -> row``.
        """
        result = AutoDic()
        sql = 'SELECT * FROM price_policy where cid = %s'
        async with MysqlUtil() as conn:
            policies = await conn.fetchall(sql, (cid,))
        for policy in policies:
            result[str(policy['inline_vc'])][str(policy['start_month'])][
                policy['time_range']] = policy
        return result

    async def _build_max_kwh_day(self):
        """Build the ES query for this inline's max-kWh day.

        The query covers the last 6 months and is sorted by ``kwh``
        descending with ``size`` 1. It is declared async although it
        awaits nothing, so callers must await it.
        """
        dt = pendulum.now()
        dt_half_year_ago = dt.subtract(months=6)
        q = {
            "size": 1,
            "query": {
                "bool": {
                    "must": [
                        {"term": {
                            "inlid": {
                                "value": self._inlid
                            }
                        }},
                        {"range": {
                            "day": {
                                "gte": str(dt_half_year_ago),
                                "lt": str(dt)
                            }
                        }}
                    ]
                }
            },
            "sort": [
                {
                    "kwh": {
                        "order": "desc"
                    }
                }
            ]
        }
        return q

    async def _find_kwh_max_day(self):
        """Find the max-kWh day of this inline in the latest 6 months.

        :return: the doc's ``day`` parsed as a pendulum DateTime in
            Asia/Shanghai, or None if there is no matching doc.
        """
        q = await self._build_max_kwh_day()
        async with EsUtil() as es:
            search_rlt = await es.search_origin(
                body=q,
                index=INLINE_1DAY_POWER_ESINDEX)
        hits_list = search_rlt['hits']['hits']
        if not hits_list:
            return None
        day_str = hits_list[0]['_source']['day']
        return pendulum.from_format(day_str, 'YYYY-MM-DDTHH:mm:ssZ',
                                    tz='Asia/Shanghai')

    def _build_aggs_kwh(self, p_char, start_dt):
        """Build an ES query summing ``kwh`` for ``spfv == p_char``.

        The sum covers quarter_time in ``[start_dt, start_dt + 1 day)``.
        """
        end_dt = start_dt.add(days=1)
        q = {
            "size": 0,
            "query": {
                "bool": {
                    "must": [
                        {"term": {
                            "inlid": {
                                "value": self._inlid
                            }
                        }},
                        {"term": {
                            "spfv": {
                                "value": p_char
                            }
                        }},
                        {"range": {
                            "quarter_time": {
                                "gte": str(start_dt),
                                "lt": str(end_dt)
                            }
                        }}
                    ]
                }
            },
            "aggs": {
                "kwh": {
                    "sum": {
                        "field": "kwh"
                    }
                }
            }
        }
        return q

    async def _get_total_kwh_of_one_pchar(self, p_char, start_dt):
        """Return the summed ``kwh`` of period ``p_char`` on the day at start_dt."""
        q = self._build_aggs_kwh(p_char, start_dt)
        async with EsUtil() as es:
            search_rlt = await es.search_origin(
                body=q, index=INLINE_15MIN_POWER_ESINDEX)
        return search_rlt['aggregations']['kwh']['value']

    def _build_kw_curve(self, start_dt):
        """Build an ES query for ``quarter_time`` and ``p`` of one day.

        The day is ``[start_dt, start_dt + 1 day)``, sorted by quarter_time
        ascending.
        """
        end_dt = start_dt.add(days=1)
        q = {
            # NOTE(review): presumably one doc per quarter hour (96 per
            # day), so 100 covers a full day -- confirm.
            "size": 100,
            "_source": ["quarter_time", "p"],
            "query": {
                "bool": {
                    "must": [
                        {"term": {
                            "inlid": {
                                "value": self._inlid
                            }
                        }},
                        {"range": {
                            "quarter_time": {
                                "gte": str(start_dt),
                                "lt": str(end_dt)
                            }
                        }}
                    ]
                }
            },
            "sort": [
                {
                    "quarter_time": {
                        "order": "asc"
                    }
                }
            ]
        }
        return q

    async def _get_kw_curve(self, start_dt):
        """Fetch the day's 15-minute curve as a DataFrame.

        :return: a DataFrame with the ``_source`` fields (``quarter_time``
            as a naive ``datetime.datetime``, plus ``p``), in ascending
            quarter_time order. It is empty when ES has no docs.
        """
        q = self._build_kw_curve(start_dt)
        async with EsUtil() as es:
            search_rlt = await es.search_origin(
                body=q,
                index=INLINE_15MIN_POWER_ESINDEX)
        # hits_list is already sorted by quarter_time asc
        hits_list = search_rlt['hits']['hits']
        kw_list = []
        for item in hits_list:
            src_d = item['_source']
            dt = pendulum.from_format(src_d['quarter_time'],
                                      'YYYY-MM-DDTHH:mm:ssZ',
                                      tz='Asia/Shanghai')
            # Keep Asia/Shanghai wall-clock fields but drop tzinfo.
            src_d['quarter_time'] = datetime.datetime(
                year=dt.year, month=dt.month, day=dt.day, hour=dt.hour,
                minute=dt.minute, second=dt.second)
            kw_list.append(src_d)
        return pd.DataFrame(kw_list)


async def ess_out_result(inlid, ess_system):
    """Record an energy-storage investment analysis and return its result.

    Steps, each inside its own ``MysqlUtil()`` context:
      1. look up the company id (``cid``) owning inline ``inlid``;
      2. look up that company's proxy via ``company_proxy_map``;
      3. insert a row into ``energy_investment_analysis_record`` tagged
         ``ENERGY_INVESTMENT_ESS`` with the current Unix timestamp.
    Finally it runs ``EnergyStoreOptimize(inlid).calc_inline(ess_system)``.

    :param inlid: inline id to analyse
    :param ess_system: storage-system parameters forwarded to the optimiser
    :return: the result dict of ``EnergyStoreOptimize.calc_inline``
    """
    # 1. company that owns the inline
    async with MysqlUtil() as conn:
        inline_row = await conn.fetchone(
            sql="select cid cid_belongedto from inline where inlid = %s",
            args=(inlid,))
    cid = inline_row.get("cid_belongedto")

    # 2. proxy the company is mapped to
    proxy_sql = ("select cpm.proxy from company c inner join "
                 "company_proxy_map cpm on cpm.cid=c.cid where c.cid = %s")
    async with MysqlUtil() as conn:
        proxy_row = await conn.fetchone(proxy_sql, args=(cid,))
    proxy_id = proxy_row["proxy"]

    # 3. record of this analysis request
    record_sql = ("insert into energy_investment_analysis_record "
                  "(cid, analysis_type, inlid, proxy, time) "
                  "values (%s, %s, %s, %s, %s)")
    record_args = (cid, ENERGY_INVESTMENT_ESS, inlid, proxy_id,
                   pendulum.now().int_timestamp)
    async with MysqlUtil() as conn:
        await conn.execute(record_sql, args=record_args)

    # 4. the optimisation itself
    return await EnergyStoreOptimize(inlid).calc_inline(ess_system)
