import concurrent.futures
import logging
import os
import time
from collections import deque
from datetime import datetime
from uuid import UUID

import pytz

from firebase.firebase import Firebase
from feature.automationV2.src import nodes
from feature.automationV2.src.engine import NodeFactory
from feature.automationV2.src.integrations.database_connection import DatabaseConnection
from feature.automationV2.src.integrations.database_logger import DatabaseLogger
from feature.automationV2.src.models import Workflow
from feature.automationV2.src.models.execution_context import ExecutionContext
from feature.automationV2.src.models.execution_result import ExecutionResult

fb = Firebase(host=os.environ.get("FIREBASE_HOST"))
# All wait-time comparisons run in UTC (swap in e.g. pytz.timezone('Asia/Bangkok') if local time is required)
timezone = pytz.utc
logging.basicConfig(level=logging.INFO)

LOGGING_PREFIX = 'TRIGGER_WAITING: '

# Registry mapping node-type strings (as stored in the automation definition)
# to their executable node classes.
typeConditions = {
    'audience': nodes.AudienceNode,
    'api_call': nodes.APICallNode,
    'wait': nodes.WaitNode,
    'switchcase': nodes.SwitchCaseNode,
    'line_api': nodes.LINEAPI,
    'facebook': nodes.Facebook,
}

def update_user_status_single(database_logger, execution_id: str, node_id: str, user_pseudo_id: str, status: int):
    """Update one user's status in the running stage; returns the user id so callers can track completion."""
    database_logger.update_user_status_running_stage(execution_id, node_id, user_pseudo_id, status)
    return user_pseudo_id

def update_user_status_running_stage_wait_node(database_logger, execution_id: str, node_id: str, user_pseudo_id: str, status: int, next_node_id, next_node_type, wait_until):
    """Park one user in the wait stage, recording which node they resume at and when."""
    database_logger.update_user_status_running_stage_wait_node(execution_id, node_id, user_pseudo_id, status, next_node_id, next_node_type, wait_until)
    return user_pseudo_id

def normalize_log(log: dict) -> dict:
    """Return a copy of `log` with datetime values serialized to ISO-8601 strings."""
    return {k: v.isoformat() if isinstance(v, datetime) else v for k, v in log.items()}

def process_execution(database_logger: DatabaseLogger):
    """Fetch all users currently parked in the waiting stage; returns [] on failure so callers can iterate safely."""
    try:
        return database_logger.get_waiting_stage_context()
    except Exception as e:
        logging.error(f"{LOGGING_PREFIX}Error processing execution: {e}")
        return []

def group_automation_data(raw_data):
    """
    Groups flat automation logs into a structured dictionary:
    Automation -> Execution -> List of Users
    """
    result = {}

    for entry in raw_data:
        # Extract keys as strings
        prop_id = str(entry['property_id'])
        auto_id = str(entry['automation_id'])
        exec_id = str(entry['execution_id'])
        next_node_id = entry['next_node_id']

        # Level 1: Property ID
        if prop_id not in result:
            result[prop_id] = {}

        # Level 2: Automation ID
        if auto_id not in result[prop_id]:
            result[prop_id][auto_id] = {}

        # Level 3: Execution ID
        if exec_id not in result[prop_id][auto_id]:
            result[prop_id][auto_id][exec_id] = {
                'execution_id': exec_id,
                'steps': {}
            }

        # Level 4: Next Node ID (Steps)
        steps_group = result[prop_id][auto_id][exec_id]['steps']

        if next_node_id not in steps_group:
            # Store configuration once per node group
            steps_group[next_node_id] = {
                'target_node_id': next_node_id,
                'target_node_type': entry['next_node_type'],
                'target_node_config': entry['next_node_json'],
                'users': []
            }

        # Level 5: Process User Data
        user_obj = entry.copy()

        # Remove grouping keys from individual user records to avoid redundancy
        fields_to_remove = [
            'property_id', 'automation_id', 'execution_id', 
            'next_node_id', 'next_node_type', 'next_node_json'
        ]
        for field in fields_to_remove:
            user_obj.pop(field, None)

        # Convert UUID to string if necessary
        if isinstance(user_obj.get('user_pseudo_id'), UUID):
            user_obj['user_pseudo_id'] = str(user_obj['user_pseudo_id'])

        # Append user to the specific step
        steps_group[next_node_id]['users'].append(user_obj)

    return result


def execute_workflow_waiting():
    start_time = time.time()
    db_conn = DatabaseConnection()
    database_logger = DatabaseLogger(db_connection=db_conn)

    user_running = process_execution(database_logger)

    # Nothing to do if no user is parked in a wait stage
    if not user_running:
        logging.info(f"{LOGGING_PREFIX}No user in waiting status")
        return False, "No user in waiting status"

    # Truncate to minute precision so comparison matches the scheduler's tick granularity
    time_now = datetime.now(timezone).replace(second=0, microsecond=0)
    wait_stop_user = []
    for user in user_running:
        wait_until_str = user['wait_until']
        dt = datetime.strptime(wait_until_str, "%Y-%m-%dT%H:%M:%S")
        # localize() is correct for any pytz zone; replace(tzinfo=...) is only safe for UTC
        dt_utc = timezone.localize(dt).replace(second=0, microsecond=0)

        # Release users whose wait has expired; <= (rather than ==) also catches users
        # whose wait_until fell in a tick the scheduler skipped
        if dt_utc <= time_now:
            wait_stop_user.append(user)

    # Stop early if no user's wait has expired at this tick
    if not wait_stop_user:
        logging.info(f"{LOGGING_PREFIX}Automation waiting job: no user match {time_now.isoformat()}")
        return False, f"Automation waiting job: no user match {time_now.isoformat()}"
    
    # Group by property -> automation -> execution -> node flow
    grouped_result = group_automation_data(wait_stop_user)

    # Loop over each property and automation
    for pr in grouped_result:
        for au in grouped_result[pr]:
            try:
                # Get automation context from Firebase
                automation_context = fb.db.reference().child(f"account/{pr}/automation_v2/{au}").get()
                
                # Initialize Workflow Helper
                wf = Workflow(workflow_id=au, definition=automation_context, database_logger=database_logger)
                
                # Loop Executions
                for ex in grouped_result[pr][au]:
                    execution_data = grouped_result[pr][au][ex]
                    execution_id = execution_data['execution_id']

                    # Loop Steps (Target Nodes after Wait)
                    for next_node_id, step_data in execution_data['steps'].items():
                        try:
                            logging.info(f"{LOGGING_PREFIX} Processing {au} -> {next_node_id}")
                            
                            target_node_type = step_data['target_node_type']
                            users_list = step_data['users']
                            
                            # Extract User IDs
                            user_pseudo_ids = [u['user_pseudo_id'] for u in users_list]
                            
                            # Identify the node the users just came from (The Wait Node)
                            # We assume a batch in 'steps' comes from the same previous node
                            previous_node_id = users_list[0]['current_node_id'] if users_list else None

                            # 1. Update status to Running (3) to lock these users
                            #    (threaded: each update is an independent DB round-trip)
                            max_workers = max(1, min(16, len(user_pseudo_ids)))
                            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                                futures = {
                                    executor.submit(update_user_status_single, database_logger, execution_id, previous_node_id, user, 3): user
                                    for user in user_pseudo_ids
                                }
                                for future in concurrent.futures.as_completed(futures):
                                    try:
                                        future.result()  # surface any worker exception
                                    except Exception as e:
                                        logging.error(f"{LOGGING_PREFIX}Status update failed for {futures[future]}: {e}")

                            # 2. Initialize the queue for the workflow engine,
                            #    starting immediately at next_node_id
                            queue = deque([(next_node_id, previous_node_id)])
                            masterResult = ExecutionResult(success=True)
                            face_wait = False  # bound before the loop so the post-loop check can never hit an unbound name

                            # Set the users for the initial context
                            success_user_pseudo_ids = user_pseudo_ids

                            # 3. BFS Execution Loop
                            while queue:
                                node_perform_id, prev_node_id = queue.popleft()
                                
                                # Safety check if node exists in flow
                                if node_perform_id not in automation_context['nodes']:
                                    logging.error(f"Node {node_perform_id} not found in automation {au}")
                                    continue

                                node_context_json = automation_context['nodes'][node_perform_id]
                                node_perform_type = node_context_json['type']
                                node_perform_name = node_context_json['name']

                                # Prepare factory; skip node types with no registered implementation
                                node_class = typeConditions.get(node_perform_type)
                                if node_class is None:
                                    logging.error(f"{LOGGING_PREFIX}Unknown node type '{node_perform_type}' for node {node_perform_id}")
                                    continue
                                node_factory = NodeFactory()
                                node_factory.register_node_type(node_perform_type, node_class)
                                node_build = node_factory.create_node(node_context_json, pr, database_logger=database_logger)

                                # Context Management
                                # If we are deeper in the loop, check if users were filtered by previous nodes
                                if node_perform_id != next_node_id: # If not the very first node of this batch
                                    if masterResult.data and prev_node_id in masterResult.data:
                                        if 'user_pseudo_ids' in masterResult.data[prev_node_id]:
                                            success_user_pseudo_ids = masterResult.data[prev_node_id]['user_pseudo_ids']
                                        elif node_perform_type == 'switchcase':
                                            # Switchcase branching would narrow the user list per branch here;
                                            # intentionally a no-op in this waiting-trigger job
                                            pass

                                context = ExecutionContext(au, user_id='SYSTEM_WAIT_TRIGGER')
                                context.set_data('workflow_id', au)
                                context.set_data('execution_id', execution_id)
                                context.data['user_pseudo_ids'] = success_user_pseudo_ids

                                # Execute Node
                                try:
                                    result = node_build.execute(context=context)
                                    masterResult.merge_data(result.data)
                                except Exception as e:
                                    logging.error(f"Node Execution Failed {node_perform_id}: {e}")
                                    continue

                                # Log Results (Threaded)
                                log_result = result.data.get(node_perform_id, {}).get('log_result', [])
                                if log_result:
                                    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as log_executor:
                                        log_futures = []
                                        for log in log_result:
                                            clean_log = normalize_log(log)
                                            logData = {
                                                'property_id': pr,
                                                'user_pseudo_id': log['user_pseudo_id'],
                                                'automation_id': au,
                                                'execution_id': execution_id,
                                                'node_id': node_perform_id,
                                                'node_type': node_perform_type,
                                                'node_name': node_perform_name,
                                                'input_data': {},  # could be populated from the previous step if needed
                                                'output_data': clean_log,
                                                'status': log['status'],
                                                'execution_time_ms': 0,
                                                'error_message': str(log['message']) if log.get('message') is not None else None
                                            }
                                            log_futures.append(log_executor.submit(database_logger.log_node_context, logData))
                                        for f in concurrent.futures.as_completed(log_futures):
                                            _ = f.result()

                                # Handle wait node (stop processing for these users)
                                if node_perform_type == 'wait':
                                    face_wait = True
                                    # key spelling matches what WaitNode emits
                                    wait_until = result.data[node_perform_id]['wait_calcualted']
                                    
                                    # Identify the next node after THIS wait (None if the wait is terminal)
                                    next_after_wait_id = (wf.connections.get(node_perform_id) or [None])[0]
                                    if next_after_wait_id:
                                        next_after_wait_type = automation_context['nodes'][next_after_wait_id]['type']
                                        
                                        # Update DB: status 3 -> wait, with the new wait_until
                                        with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(16, len(success_user_pseudo_ids)))) as executor:
                                            futures = {
                                                executor.submit(update_user_status_running_stage_wait_node, 
                                                                database_logger, execution_id, node_perform_id, 
                                                                user, 3, next_after_wait_id, next_after_wait_type, wait_until): user
                                                for user in success_user_pseudo_ids
                                            }
                                            for f in concurrent.futures.as_completed(futures):
                                                _ = f.result()
                                    
                                    # Stop the whole queue: these users are now parked waiting
                                    # (note: any other branches still queued are abandoned)
                                    break

                                # Add Next Tasks to Queue
                                next_tasks = wf.connections.get(node_perform_id, [])
                                for next_task in next_tasks:
                                    queue.append((next_task, node_perform_id))
                            
                            # End of queue loop: if it drained without hitting a wait node,
                            # these users completed the flow and can be cleared from the running stage.
                            if face_wait:
                                logging.info(f"{LOGGING_PREFIX}Hit wait node: {au} in {round(time.time() - start_time, 2)} sec")
                                database_logger.update_trigger_mannager_status(pr, au, execution_id, 'running')
                            else:
                                logging.info(f"{LOGGING_PREFIX}Finished {au} in {round(time.time() - start_time, 2)} sec")
                                database_logger.update_trigger_mannager_status(pr, au, execution_id, 'completed', datetime.now(timezone).strftime("%Y-%m-%d %H:%M:%S"))

                                # Clear completed users from the running stage
                                rowCount = database_logger.clear_user_from_running_stage(pr, au, execution_id)
                                logging.info(f"{LOGGING_PREFIX}Cleared {rowCount} rows from running stage")
                            
                        except Exception as e:
                            logging.error(f"{LOGGING_PREFIX}Error processing step {next_node_id}: {e}")

            except Exception as e:
                logging.error(f"{LOGGING_PREFIX}Error processing automation {au}: {e}")
    
    return True, "All waiting is done"
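

# Minimal entry-point sketch, assuming this module runs as a standalone scheduled
# job (e.g. a per-minute cron tick); the trigger mechanism is an assumption, not
# something confirmed by the rest of this file.
if __name__ == "__main__":
    ok, message = execute_workflow_waiting()
    logging.info(f"{LOGGING_PREFIX}{message}")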