from datetime import datetime
import pendulum
from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api import constants
from unify_api.constants import CST


def point_day2month(dt):
    """Return the monthly 1-minute point index name for ``dt``.

    ``dt`` may be a Unix timestamp (``int``/``float``, converted to a
    datetime in the ``Asia/Shanghai`` timezone) or a ``datetime``.

    For either of those the result is
    ``f"{constants.POINT_1MIN_INDEX}_{year}_{month}"``, with the month not
    zero-padded. Any other value yields the bare
    ``constants.POINT_1MIN_INDEX``.
    """
    if isinstance(dt, (int, float)):
        moment = pendulum.from_timestamp(dt, tz="Asia/Shanghai")
    elif isinstance(dt, datetime):
        moment = dt
    else:
        # Unknown input type: fall back to the un-suffixed base index.
        return constants.POINT_1MIN_INDEX

    return f"{constants.POINT_1MIN_INDEX}_{moment.year}_{moment.month}"


async def today_alarm_cnt(cids):
    """Count today's alarm-mode events for each company.

    "Today" is the current calendar day in ``Asia/Shanghai``. Events with
    ``start_of_today <= event_datetime < start_of_tomorrow`` whose event
    type has ``mode = 'alarm'`` are counted.

    :param cids: sequence of company ids.
    :return: ``{cid: {"today_alarm_count": int}}``. Every cid in ``cids``
        gets an entry, with 0 when it had no alarms today. An empty
        ``cids`` returns ``{}``.
    """
    if not cids:
        # ``cid in ()`` is invalid SQL, and there is nothing to count anyway.
        return {}

    day_start = pendulum.today(tz="Asia/Shanghai")
    day_end = day_start.add(days=1)
    fmt = "YYYY-MM-DD HH:mm:ss"
    start_str = day_start.format(fmt)
    end_str = day_end.format(fmt)

    sql = """
        select cid,count(*) count
        from point_1min_event pe
        left join event_type et on pe.event_type = et.e_type
        where cid in %s and et.mode = 'alarm' and event_datetime >= %s
        and event_datetime < %s
        group by cid
    """
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sql,
                                    args=(cids, start_str, end_str))

    cid_bucket_map = {row["cid"]: row["count"] for row in datas}

    # Look up by the actual cid value. Cids without alarms are absent from
    # the query result and therefore default to 0.
    return {
        cid: {"today_alarm_count": cid_bucket_map.get(cid) or 0}
        for cid in cids
    }


async def proxy_safe_run_info(cids, start_time_str=None,
                              end_time_str=None):
    """
    Batch-fetch each factory's safe-run days and today's alarm count.

    Without a time window, total days run from the company's
    ``create_time`` up to now. When both ``start_time_str`` and
    ``end_time_str`` are given (e.g. a month), total days are counted
    within that window instead, starting no earlier than ``create_time``.

    An "alarm day" is a calendar date on which ``point_1min_event`` has at
    least one row for the cid (restricted to the window, if one is given).
    The result is ``safe_run_days = total_days - alarm_days``.

    :param cids: sequence of company ids.
    :param start_time_str: optional window start. Naive strings are
        interpreted in the ``CST`` timezone.
    :param end_time_str: optional window end. It is exclusive in the event
        query, and the query bound is capped at the current time.
    :return: ``{cid: {"today_alarm_count": int, "safe_run_days": int,
        "total_days": int}}``, or ``{}`` for an empty ``cids``.
    :raises KeyError: if a cid has no row in ``company``.
    """
    if not cids:
        # ``cid in ()`` is invalid SQL, and there is nothing to report.
        return {}

    now_dt = pendulum.now(tz=CST)
    has_window = bool(start_time_str and end_time_str)
    start_dt, end_dt, start_ts, end_ts = None, None, 0, 0
    where = ""
    args = [cids]
    if has_window:
        # Parse in CST so these timestamps compare correctly with the
        # CST-aware now_dt and company create_time values used below.
        start_dt = pendulum.parse(start_time_str, tz=CST)
        end_dt = pendulum.parse(end_time_str, tz=CST)
        start_ts = start_dt.int_timestamp
        end_ts = end_dt.int_timestamp
        if end_ts > now_dt.int_timestamp:
            end_time_str = now_dt.format("YYYY-MM-DD HH:mm:ss")
        # Bind the bounds as query parameters instead of splicing
        # caller-supplied strings into the SQL text.
        where = " and event_datetime >= %s and event_datetime < %s "
        args += [start_time_str, end_time_str]

    # ``%%`` escapes ``%`` for the driver's %-style parameter substitution
    # (the query also carries ``%s`` placeholders).
    sql = f"""
        select cid,date_format(event_datetime,"%%Y-%%m-%%d") fmt_day,
        count(*) count
        from point_1min_event
        where cid in %s {where}
        group by cid,date_format(event_datetime,"%%Y-%%m-%%d")
    """
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql=sql, args=tuple(args))

    # Fetch each company's install time (create_time).
    async with MysqlUtil() as conn:
        company_sql = "select cid, create_time from company where cid in %s"
        companys = await conn.fetchall(company_sql, (cids,))

    create_dt_map = {
        company["cid"]: pendulum.from_timestamp(company["create_time"],
                                                tz=CST)
        for company in companys
    }

    # Each result row is one (cid, day) group, so the number of rows with a
    # positive count for a cid is that cid's number of alarm days.
    alarm_days_map = {}
    for data in datas:
        if (data.get("count") or 0) > 0:
            cid = data.get("cid")
            alarm_days_map[cid] = alarm_days_map.get(cid, 0) + 1

    cid_alarm_map = {cid: {"today_alarm_count": 0, "safe_run_days": 0}
                     for cid in cids}
    for cid in cids:
        create_dt = create_dt_map[cid]
        if has_window:
            create_ts = create_dt.int_timestamp
            if start_ts < create_ts < end_ts:
                # Installed mid-window: count from the install date.
                total_days = (end_dt - create_dt).days + 1
            elif create_ts > end_ts:
                # Installed after the window: no running days inside it.
                total_days = 0
            else:
                total_days = (end_dt - start_dt).days + 1
            # NOTE(review): end_dt is not capped at now (only the query's
            # end_time_str is), so future days inside the window count as
            # safe days. Confirm this is intended.
        else:
            total_days = (now_dt - create_dt).days + 1

        cid_alarm_map[cid]["safe_run_days"] = (
            total_days - alarm_days_map.get(cid, 0))
        cid_alarm_map[cid]["total_days"] = total_days

    today_alarm_map = await today_alarm_cnt(cids)
    for cid in cid_alarm_map:
        cid_alarm_map[cid]["today_alarm_count"] = today_alarm_map[cid][
            "today_alarm_count"]
    return cid_alarm_map


async def alarm_time_distribution(company_ids, start, end):
    """Bucket point events for ``company_ids`` by hour of day.

    Counts ``point_1min_event`` rows with
    ``start <= event_datetime <= end`` (both bounds inclusive) and sums them
    into three buckets based on the event's hour:

    * ``day_alarm_cnt``     -- hours 6..17
    * ``night_alarm_cnt``   -- hours 18..23
    * ``morning_alarm_cnt`` -- hours 0..5

    :param company_ids: sequence of company ids. An empty sequence returns
        all-zero buckets without querying.
    :param start: lower bound, sent to MySQL as a bound parameter.
    :param end: upper bound (inclusive), sent as a bound parameter.
    :return: dict with the three bucket counts.
    """
    time_distribution_map = {"day_alarm_cnt": 0, "night_alarm_cnt": 0,
                             "morning_alarm_cnt": 0}
    if not company_ids:
        # ``cid IN ()`` is invalid SQL, and there are no events to count.
        return time_distribution_map

    # start/end are bound as parameters rather than interpolated into the
    # SQL text, so caller-supplied values cannot alter the query.
    sql = """
        SELECT
            HOUR (pevent.event_datetime) event_hour,
            COUNT(*) event_count
        FROM
            point_1min_event pevent
        WHERE
            cid IN %s
        AND pevent.event_datetime >= %s
        AND pevent.event_datetime <= %s
        GROUP BY
            HOUR (pevent.event_datetime)
    """
    async with MysqlUtil() as conn:
        datas = await conn.fetchall(sql, args=(company_ids, start, end))

    for data in datas:
        hour = int(data["event_hour"])
        if 6 <= hour < 18:
            time_distribution_map["day_alarm_cnt"] += data["event_count"]
        elif 18 <= hour <= 23:
            time_distribution_map["night_alarm_cnt"] += data["event_count"]
        else:
            time_distribution_map["morning_alarm_cnt"] += data["event_count"]
    return time_distribution_map