import sys
import os
import pytz
import time
import logging
from collections import deque
from datetime import datetime
from connectors.firebase.firebase import Firebase

from api.feature.automationV2.src.models.execution_context import ExecutionContext
from api.feature.automationV2.src.models.execution_result import ExecutionResult
from api.feature.automationV2.src.models import Workflow
from api.feature.automationV2.src import nodes
from api.feature.automationV2.src.engine import NodeFactory

from api.feature.automationV2.src.integrations.database_connection import DatabaseConnection
from api.feature.automationV2.src.integrations.database_logger import DatabaseLogger
import concurrent.futures

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
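# Shared Firebase client, created once at import time. Assumption: the client
# is treated as safe to share across the worker threads used below.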
fb = Firebase(host=os.environ.get("FIREBASE_HOST"))
# timezone = pytz.timezone('Asia/Bangkok')
timezone = pytz.utc
logging.basicConfig(level=logging.INFO)

# Registry mapping each workflow node "type" string to its node class.
typeConditions = {
    'audience': nodes.AudienceNode,
    'api_call': nodes.APICallNode,
    'wait': nodes.WaitNode,
    'switchcase': nodes.SwitchCaseNode,
    'line_api': nodes.LINEAPI,
    'facebook': nodes.Facebook,
    'sms': nodes.SMS,
    'email': nodes.Email,
}
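
# To support a new channel, register its node class under the "type" string
# used in the workflow JSON. Hypothetical example (nodes.WebhookNode is
# illustrative, not an existing class):
#     typeConditions['webhook'] = nodes.WebhookNode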

def update_user_status_single(database_logger, execution_id: str, node_id: str, user_pseudo_id: str, status: int):
    """Update one user's status in the running stage and return the user id."""
    database_logger.update_user_status_running_stage(execution_id, node_id, user_pseudo_id, status)
    return user_pseudo_id

def update_user_status_running_stage_wait_node(database_logger, execution_id: str, node_id: str, user_pseudo_id: str, status: int, next_node_id, next_node_type, wait_until):
    """Update one user's status at a wait node, recording where and when to resume."""
    database_logger.update_user_status_running_stage_wait_node(execution_id, node_id, user_pseudo_id, status, next_node_id, next_node_type, wait_until)
    return user_pseudo_id
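
# Note: both wrappers above return user_pseudo_id so the result of a completed
# future identifies which user was updated when draining as_completed().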

def process_execution(execution, database_logger):
    try:
        property_id = execution.get("property_id")
        automation_id = execution.get("automation_id")
        execution_id = str(execution.get("execution_id"))
        path_to_json_flow = execution.get("path_to_json_flow")

        # Update trigger manager: change status to 'running'
        database_logger.update_trigger_mannager_status(property_id, automation_id, execution_id, 'running')

        # Load workflow
        automation_json = fb.db.reference().child(path_to_json_flow).get()
        wf = Workflow(
            automation_json['automation_id'],
            automation_json,
            database_logger=database_logger
        )

        context = ExecutionContext(
            workflow_id=automation_id, 
            user_id="test_user_123"
        )

        node_trigger = wf.start_node_id
        audience_node_json = automation_json['nodes'][wf.connections[node_trigger][0]]

        node_factory = NodeFactory()
        node_factory.register_node_type('audience', nodes.AudienceNode)
        audience_node = node_factory.create_node(
            audience_node_json, wf.property_id, database_logger=database_logger
        )
        audience_context = audience_node.execute(context)

        if not audience_context.is_success():
            logging.error(f"Audience execution failed for automation {automation_id}")
            return

        user_pseudo_ids = audience_context.data['user_pseudo_ids']
        total_user_pseudo_ids = len(user_pseudo_ids)
        logging.info(f"[{automation_id}] total_user_pseudo_ids: {total_user_pseudo_ids}")
        logging.info("-" * 50)

        next_node_ids = automation_json['connections'][audience_node.node_id]

        def process_user(user):
            try:
                for next_node in next_node_ids:
                    data_log = {
                        "property_id": wf.property_id,
                        "automation_id": automation_id,
                        "execution_id": execution_id,
                        "user_pseudo_id": user,
                        "current_node_id": audience_node.node_id,
                        "current_node_type": audience_node.node_type,
                        "current_node_json": audience_node_json,
                        "next_node_id": next_node,
                        "next_node_type": automation_json['nodes'][next_node]['type'],
                        "next_node_json": automation_json['nodes'][next_node],
                        "wait_until": '-',
                        "input_data": {'user_pseudo_id': user},
                        "status": 0,
                        "is_moved": 0,
                        "execution_time_ms": 0,
                        "error_message": None,
                    }

                    user_node_context = {
                        "execution_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'property_id': property_id,
                        'user_pseudo_id': user,
                        'automation_id': automation_id,
                        'execution_id': execution_id,
                        'node_id': audience_node.node_id,
                        'node_type': audience_node.node_type,
                        'input_data': {},
                        'output_data': {'user_pseudo_id': user},
                        'status': 'completed',
                        'execution_time_ms': 0,
                        "error_message": None                    }
                    database_logger.add_data_to_running_stage(data_log)
                    database_logger.log_node_context(user_node_context)
            except Exception as e:
                logging.error(f"Error processing user {user} in automation {automation_id}: {e}")

        # Parallelize users within this execution
        max_workers_inner = max(1, min(20, total_user_pseudo_ids))  # ThreadPoolExecutor requires max_workers >= 1
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers_inner) as user_executor:
            futures = [user_executor.submit(process_user, user) for user in user_pseudo_ids]
            for f in concurrent.futures.as_completed(futures):
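                # process_user traps its own exceptions, so draining results
                # here just blocks until every user has been fanned out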
                _ = f.result()

    except Exception as e:
        logging.error(f"Error processing execution {execution.get('execution_id')}: {e}")

# Start here

# Install audience, then advance users through the workflow graph
def execute_workflow():
    db_conn = DatabaseConnection()
    database_logger = DatabaseLogger(db_connection=db_conn)
    now = datetime.now()
    time_trigger = now.strftime('%Y-%m-%d %H:%M:00')
    logging.info(f"Automation time: {time_trigger}")
    executions = database_logger.get_active_workflow(time_trigger)
    start_time = time.time()
    max_workers_outer = min(10, len(executions))  # avoid overloading
    logging.info(f"Total executions: {len(executions)}")
    if len(executions) > 0:
        logging.info("Processing active executions")
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers_outer) as executor:
            futures = [executor.submit(process_execution, execution, database_logger) for execution in executions]
            for f in concurrent.futures.as_completed(futures):
                _ = f.result()
        end_time_1 = time.time()
        elapsed_seconds_1 = round(end_time_1 - start_time, 2)
        logging.info(f"\nAll executions finished in {elapsed_seconds_1} seconds")


        #Running workflow
        executions_node = database_logger.get_running_stage_context()
        # Get automation context
        automationIds = {(item['automation_id'], item['property_id']) for item in executions_node}
        automationContext = {
            auto: fb.db.reference().child(f"account/{property_id}/automation_v2/{auto}").get()
            for auto, property_id in automationIds
        }

        node_context = {
            f"{item['property_id']}_{item['automation_id']}_{item['execution_id']}_{item['next_node_id']}" : item['next_node_json']
            for item in executions_node
        }
        auto_execute_currentnode_nextnode = list(set(
            (i['property_id'], i['automation_id'], str(i['execution_id']),
             i['current_node_id'], i['current_node_type'],
             i['next_node_id'], i['next_node_type'])
            for i in executions_node
        ))

        def process_execution_node_pair(execute_cnode_nnode):
            try:
                logging.info('+' * 100)
                (property_id, automation_id, execution_id,
                 current_node_id, current_node_type,
                 next_node_id, next_node_type) = execute_cnode_nnode

                database_logger.update_trigger_mannager_status(property_id, automation_id, execution_id, 'running')

                wf = Workflow(workflow_id=automation_id, definition=automationContext[automation_id])
                task_flow = wf.connections
                final_node = wf.final_nodes

                # Get user_pseudo_ids
                user_pseudo_ids = [
                    str(item['user_pseudo_id'])
                    for item in executions_node
                    if item['next_node_id'] == next_node_id
                    and item['property_id'] == property_id
                    and item['automation_id'] == automation_id
                ]

                # Update running status to 3 in parallel
                max_workers = max(1, min(16, len(user_pseudo_ids)))  # ThreadPoolExecutor requires max_workers >= 1
                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = {
                        executor.submit(update_user_status_single, database_logger, execution_id, current_node_id, user, 3): user
                        for user in user_pseudo_ids
                    }
                    for future in concurrent.futures.as_completed(futures):
                        user = futures[future]
                        try:
                            updated_user = future.result()
                            logging.info(f"Updated user {updated_user}")
                        except Exception as e:
                            logging.error(f"Error updating user {user}: {e}")

                # Breadth-first walk of the workflow graph from the current node
                logging.info(f'next_node_id {next_node_id}')
                queue = deque([(next_node, current_node_id) for next_node in wf.connections[current_node_id]])
                parents = {}
                masterResult = ExecutionResult(success=True)
                hit_wait_node = False  # True once the chain pauses at a wait node

                while queue:
                    node_perform_id, previous_node_id = queue.popleft()
                    node_perform_context = automationContext[automation_id]['nodes'][node_perform_id]
                    previous_node_context = automationContext[automation_id]['nodes'][previous_node_id]
                    node_perform_type = node_perform_context['type']
                    node_perform_name = node_perform_context['name']

                    parents[node_perform_id] = previous_node_id
                    previous_node_type = previous_node_context['type'] if previous_node_context else None

                    node_factory = NodeFactory()
                    # Check if node type is supported
                    if node_perform_type not in typeConditions:
                        logging.error(f"Unsupported node type: {node_perform_type} for node {node_perform_id}")
                        continue
                    node_factory.register_node_type(node_perform_type, typeConditions[node_perform_type])
                    node_build = node_factory.create_node(node_perform_context, property_id, database_logger=database_logger)

                    # Determine which users flow into this node from its predecessor
                    previous_data = masterResult.data.get(previous_node_id) if masterResult.data else None
                    if previous_data and previous_node_type == 'switchcase':
                        logging.info('After switchcase')
                        success_user_pseudo_ids = previous_data['conditions'][node_perform_id]['user_pseudo_ids']
                        logging.info(f'User node: {node_perform_id} | {success_user_pseudo_ids}')
                    elif previous_data and 'user_pseudo_ids' in previous_data:
                        success_user_pseudo_ids = previous_data['user_pseudo_ids']
                    else:
                        success_user_pseudo_ids = user_pseudo_ids

                    context = ExecutionContext(automation_id, user_id='ADMIN')
                    context.set_data('workflow_id', automation_id)
                    context.set_data('execution_id', execution_id)
                    context.data['user_pseudo_ids'] = success_user_pseudo_ids

                    # Execute node
                    result = node_build.execute(context=context)
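                    # Merge the node's output into masterResult so later
                    # iterations can read their predecessor's user lists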
                    masterResult.merge_data(result.data)

                    # Parallelize logging of each log_result
                    log_result = result.data.get(node_perform_id, {}).get('log_result', [])
                    if log_result:
                        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as log_executor:
                            log_futures = []
                            for log in log_result:
                                logData = {
                                    'property_id': property_id,
                                    'user_pseudo_id': log['user_pseudo_id'],
                                    'automation_id': automation_id,
                                    'execution_id': execution_id,
                                    'node_id': node_perform_id,
                                    'node_type': node_perform_type,
                                    'node_name': node_perform_name,
                                    'input_data': {},
                                    'output_data': log,
                                    'status': log['status'],
                                    'execution_time_ms': 0,
                                    "error_message": str(log.get('message', None))
                                }
                                log_futures.append(log_executor.submit(database_logger.log_node_context, logData))

                            # Wait for logs to complete
                            for f in concurrent.futures.as_completed(log_futures):
                                _ = f.result()
                    
                    # Check for a wait node: pause this chain and persist resume info
                    if node_perform_type == 'wait':
                        hit_wait_node = True

                        wait_until = result.data.get(node_perform_id, {}).get('wait_calcualted', '-')  # key kept as spelled by the upstream wait node
                        logging.info(f"wait_until: {wait_until}")

                        next_node_id_temp = wf.connections[node_perform_id][0]
                        next_node_type_temp = wf.nodes[next_node_id_temp]['type']
                        # wait_until = '2025-12-12 00:00:00'
                        logging.info(f"NEXT_NODE_TEMP {next_node_id_temp} : {wait_until}")
                        
                        max_workers = max(1, min(16, len(success_user_pseudo_ids)))  # ThreadPoolExecutor requires max_workers >= 1
                        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                            futures = {
                                executor.submit(update_user_status_running_stage_wait_node, database_logger, execution_id, node_perform_id, user, 3, next_node_id_temp, next_node_type_temp, wait_until): user
                                for user in success_user_pseudo_ids
                            }
                            for future in concurrent.futures.as_completed(futures):
                                user = futures[future]
                                try:
                                    updated_user = future.result()
                                    logging.info(f"Updated user {updated_user}")
                                except Exception as e:
                                    logging.error(f"Error updating user {user}: {e}")

                        # Users are parked at the wait node; stop advancing this chain
                        break

                    # Add next tasks to queue
                    next_tasks = task_flow.get(node_perform_id, [])
                    for next_task in next_tasks:
                        queue.append((next_task, node_perform_id))
                
                if hit_wait_node:
                    logging.info(f"Paused at wait node: {automation_id} in {round(time.time() - start_time, 2)} sec")
                    database_logger.update_trigger_mannager_status(property_id, automation_id, execution_id, 'running')
                else:
                    logging.info(f"Finished {automation_id} in {round(time.time() - start_time, 2)} sec")
                    database_logger.update_trigger_mannager_status(property_id, automation_id, execution_id, 'completed', datetime.now(timezone).strftime("%Y-%m-%d %H:%M:%S"))

                    # Clear users from the running stage
                    rowCount = database_logger.clear_user_from_running_stage(property_id, automation_id, execution_id)
                    logging.info(f"DELETE: {rowCount}")

            except Exception as e:
                logging.info(f"Error in automation {execute_cnode_nnode}: {e}")

        max_outer = max(1, min(8, len(auto_execute_currentnode_nextnode)))  # ThreadPoolExecutor requires max_workers >= 1
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_outer) as executor:
            futures = [executor.submit(process_execution_node_pair, item) for item in auto_execute_currentnode_nextnode]
            for f in concurrent.futures.as_completed(futures):
                _ = f.result()

        end_time_2 = time.time()
        elapsed_seconds_2 = round(end_time_2 - start_time, 2)
        logging.info(f"\nAll Workflow finished in {elapsed_seconds_2} seconds")
    
    return True
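

# Entry point for direct invocation. Assumption: this module may also be
# imported by a scheduler that calls execute_workflow() itself, so the call
# is kept behind the standard guard.
if __name__ == '__main__':
    execute_workflow()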