from pot_libs.mysql_util.mysql_util import MysqlUtil
from unify_api.modules.product_info.procedures.hardware_pds import (
    get_user_hardware_info, hardware_statistics)


async def check_company_exist(company_id):
    '''
    Return True if a company (factory) row with ``cid == company_id`` exists.

    :param company_id: value compared against ``company.cid``.
    :return: bool.
    '''
    raw_sql = "select count(*) as company_count from company where cid = %s"
    async with MysqlUtil() as conn:
        # Note the trailing comma: ``(company_id)`` is just ``company_id``,
        # not a 1-tuple. If it happened to be a list/tuple it would be
        # expanded into several parameters instead of one.
        result = await conn.fetchone(sql=raw_sql, args=(company_id,))
    return bool(result) and result.get('company_count', 0) > 0


async def equip_management_list(company_id, page_num, page_size):
    '''
    Get one page of monitoring points for the equipment-management view.

    Keeps the legacy implementation for now; it is meant to be unified in
    the 1.0 redesign. Each row returned by ``get_user_hardware_info`` is
    trimmed down to the fields listed below (missing ones become None).

    :return: the dict from ``get_user_hardware_info`` with its ``rows``
        replaced by the trimmed rows.
    '''
    result = await get_user_hardware_info(company_id, page_num, page_size)
    wanted_keys = (
        "installed_location", "device_number", "wiring_type", "ct_change",
        "pt_change", "rated_voltage", "start_time",
    )
    result['rows'] = [
        {key: row.get(key) for key in wanted_keys}
        for row in result.get('rows')
    ]
    return result


async def equip_management_total(company_id):
    '''
    Get the summary figures for the equipment-management view.

    Thin wrapper that returns ``hardware_statistics(company_id)`` unchanged.
    '''
    return await hardware_statistics(company_id)


async def equip_run_list(company_id, point_ids, start_time, end_time,
                         page_num, page_size, sort_field, sort_type):
    '''
    Get one page of equipment run records for a company.

    :param company_id: matched against ``point.cid`` or
        ``monitor_reuse.cid``.
    :param point_ids: optional iterable of point ids to restrict to; a falsy
        value means "all points".
    :param start_time: inclusive lower bound for the record's start_time.
    :param end_time: inclusive upper bound for the record's start_time.
    :param page_num: 1-based page number.
    :param page_size: number of rows per page.
    :param sort_field: ``point_name``, ``run_time`` or a plain /
        table-qualified column name.
    :param sort_type: ``asc`` / ``desc`` (case-insensitive) or empty.
    :return: ``(rows, total)``; rows is the ``fetchall`` result for the
        page, or ``[]`` when total is 0.
    :raises ValueError: when there are rows to list and sort_field /
        sort_type are not safe to place in the ORDER BY clause.
    '''
    # A record is listed if it has finished (end_time > 0), or if it is the
    # still-open (end_time = 0) most recent record of its point; ``sp``
    # holds each point's latest start_time.
    from_where = (
        "from scope_equip_run_record s "
        "left join (select pid,max(start_time) start_time "
        "from scope_equip_run_record group by pid) sp "
        "on s.pid = sp.pid "
        "left join point p on s.pid=p.pid "
        "left join monitor_reuse r on p.mtid = r.mtid "
        "where (p.cid=%s or r.cid = %s) and s.start_time "
        "BETWEEN %s and %s and (s.end_time > 0 or "
        "(s.end_time = 0 and s.start_time = sp.start_time)) "
    )
    args = [company_id, company_id, start_time, end_time]
    if point_ids:
        from_where += " and s.pid in %s"
        args.append(tuple(point_ids))

    async with MysqlUtil() as conn:
        # Total first, so the page query can be skipped when empty.
        count_sql = "SELECT count(*) as run_count " + from_where
        count_result = await conn.fetchone(sql=count_sql, args=tuple(args))
        run_count = count_result.get("run_count", 0)

        list_result = []
        if run_count > 0:
            order_by = _equip_run_order_by(sort_field, sort_type)
            list_sql = (
                "SELECT s.pid,p.name point_name,s.start_time,s.end_time "
                + from_where
                + " order by " + order_by + " LIMIT %s OFFSET %s"
            )
            page_args = tuple(args) + (page_size,
                                       (page_num - 1) * page_size)
            list_result = await conn.fetchall(sql=list_sql, args=page_args)
    return list_result, run_count


def _equip_run_order_by(sort_field, sort_type):
    '''
    Build a validated ``ORDER BY`` expression for ``equip_run_list``.

    Column names and sort directions cannot be bound as query parameters,
    so they end up in the SQL text; this validates them first to prevent
    SQL injection.

    :raises ValueError: if sort_field is not an alias or a plain /
        table-qualified identifier, or sort_type is not asc/desc/empty.
    '''
    import re

    # API-level sort names that map onto SQL expressions.
    aliases = {
        "point_name": "p.name",
        "run_time": "(s.end_time-s.start_time)",
    }
    if not isinstance(sort_field, str):
        raise ValueError(f"invalid sort_field: {sort_field!r}")
    if sort_field in aliases:
        column = aliases[sort_field]
    elif re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?",
                      sort_field):
        column = sort_field
    else:
        raise ValueError(f"invalid sort_field: {sort_field!r}")

    direction = (sort_type or "").strip().lower()
    if direction not in ("", "asc", "desc"):
        raise ValueError(f"invalid sort_type: {sort_type!r}")
    return f"{column} {direction}".rstrip()


async def equip_run_statistics(company_id, point_ids, start_time, end_time):
    '''
    Get aggregate run statistics for a company's equipment.

    Uses the same record selection as ``equip_run_list``: finished records,
    plus each point's still-open (end_time = 0) latest record, whose
    start_time lies in [start_time, end_time].

    :param point_ids: optional iterable of point ids; falsy means all.
    :return: the ``fetchone`` row with keys ``total_count``, ``avg_time``,
        ``all_time`` and ``max_time``.
    '''
    # Duration of one record, in the same units as start_time/end_time.
    # Open records (end_time = 0) contribute 0. ``end_time`` is qualified
    # with ``s.`` so it cannot be ambiguous against the joined tables.
    # NOTE(review): open records still count toward total_count and pull
    # avg_time down with their 0 duration -- confirm this is intended.
    dura_time = ("case when s.end_time > 0 then s.end_time-s.start_time "
                 "else 0 end")
    count_sql = (
        f"SELECT count(*) as total_count,"
        f"avg({dura_time}) as avg_time,"
        f"sum({dura_time}) as all_time,"
        f"max({dura_time}) as max_time "
        "from scope_equip_run_record s "
        "left join (select pid,max(start_time) start_time from "
        "scope_equip_run_record group by pid) sp "
        "on s.pid = sp.pid "
        "left join point p on s.pid=p.pid "
        "left join monitor_reuse r on p.mtid = r.mtid "
        "where (p.cid=%s or r.cid = %s) "
        "and s.start_time BETWEEN %s and %s and (s.end_time > 0 "
        "or (s.end_time = 0 and s.start_time = sp.start_time)) "
    )
    args = [company_id, company_id, start_time, end_time]
    if point_ids:
        count_sql += " and s.pid in %s"
        args.append(tuple(point_ids))

    async with MysqlUtil() as conn:
        count_result = await conn.fetchone(sql=count_sql, args=tuple(args))
    return count_result


async def get_equip_run_status(point_id):
    '''
    Report whether the equipment behind a monitoring point is running.

    :param point_id: ``point.pid`` to check.
    :return:
        2 -- no monitor row was found for the point, or its
             ``is_power_equipment`` flag is 0 (not a power device);
        1 -- the latest record that has already started is still open
             (end_time = 0) or ends in the future;
        0 -- otherwise (no such record, or it has already ended).
    '''
    async with MysqlUtil() as conn:
        # Non-power equipment has no meaningful run state.
        power_equip_sql = "select is_power_equipment from monitor m " \
                          "left join point p on m.mtid = p.mtid " \
                          "where p.pid = %s"
        power_equip_result = await conn.fetchone(sql=power_equip_sql,
                                                 args=(point_id,))
        # fetchone returns None when no monitor matches the point; treat
        # that like non-power equipment instead of raising AttributeError.
        if (not power_equip_result
                or power_equip_result.get("is_power_equipment", 0) == 0):
            return 2
        raw_sql = "select (case when end_time > unix_timestamp(NOW()) then 1 " \
                  "when end_time=0 then 1 else 0 end) run_count " \
                  "from scope_equip_run_record " \
                  "WHERE pid = %s and start_time < unix_timestamp(NOW()) " \
                  "order by start_time desc limit 1"
        result = await conn.fetchone(sql=raw_sql, args=(point_id,))
    return 1 if result and result.get("run_count") else 0