from flask import Blueprint, request, jsonify
import logging
from typing import Dict, List, Any
from datetime import datetime
import pytz
import uuid

from feature.automationV2.src.integrations.database_connection import DatabaseConnection
from feature.automationV2.trigger_manager import typeConditions

execute = Blueprint('execute', __name__)
logger = logging.getLogger(__name__)


def is_valid_uuid(value) -> bool:
    """Validate UUID format."""
    try:
        uuid.UUID(str(value))
        return True
    except (ValueError, AttributeError, TypeError):
        return False


def is_valid_property_id(value) -> bool:
    """Validate property_id format (numeric string, at most 50 digits)."""
    return bool(isinstance(value, str) and value.isdigit() and len(value) <= 50)


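# Integer status codes as stored on user_pseudo_running_stage rows; values
# outside this map fall back to "pending" in _group_running_stage below.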
STATUS_MAP_STAGE = {
    0: "pending",
    1: "completed",
    2: "failed",
    3: "running",
}


def _group_running_stage(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Aggregate user_pseudo_running_stage rows per node with status buckets."""
    nodes: Dict[str, Dict[str, Any]] = {}
    for row in rows:
        node_id = row.get("next_node_id")
        node_type = row.get("next_node_type")
        user_id = str(row.get("user_pseudo_id"))
        status_int = row.get("status", 0)
        status = STATUS_MAP_STAGE.get(status_int, "pending")

        if node_id not in nodes:
            nodes[node_id] = {
                "node_id": node_id,
                "node_type": node_type,
                "node_name": None,
                "pending_user_pseudo_ids": [],
                "running_user_pseudo_ids": [],
                "completed_user_pseudo_ids": [],
                "failed_user_pseudo_ids": [],
                "execution_time_ms": 0,
                "error_count": 0,
                "sample_errors": [],
            }

        bucket_key = f"{status}_user_pseudo_ids"
        if bucket_key in nodes[node_id]:
            nodes[node_id][bucket_key].append(user_id)
    return nodes
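
# Illustrative example (made-up IDs and node type): two running-stage rows for
# node "n1", one running (status 3) and one completed (status 1), aggregate to
# a single bucket record:
#
#   _group_running_stage([
#       {"next_node_id": "n1", "next_node_type": "email",
#        "user_pseudo_id": "u1", "status": 3},
#       {"next_node_id": "n1", "next_node_type": "email",
#        "user_pseudo_id": "u2", "status": 1},
#   ])
#   # -> {"n1": {"node_id": "n1", "node_type": "email",
#   #            "running_user_pseudo_ids": ["u1"],
#   #            "completed_user_pseudo_ids": ["u2"], ...}}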


def _merge_node_context(nodes: Dict[str, Dict[str, Any]], rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Merge node_context results (completed/failed) into node buckets with additional statistics."""
    for row in rows:
        node_id = row.get("node_id")
        node_type = row.get("node_type")
        node_name = row.get("node_name")
        user_id = str(row.get("user_pseudo_id"))
        status = row.get("status")
        execution_time_ms = row.get("execution_time_ms", 0) or 0
        error_message = row.get("error_message")

        if node_id not in nodes:
            nodes[node_id] = {
                "node_id": node_id,
                "node_type": node_type,
                "node_name": node_name,
                "pending_user_pseudo_ids": [],
                "running_user_pseudo_ids": [],
                "completed_user_pseudo_ids": [],
                "failed_user_pseudo_ids": [],
                "execution_time_ms": 0,
                "error_count": 0,
                "sample_errors": [],
            }

        # node_context holds only terminal (completed / failed) rows, so any
        # status other than "completed" is bucketed as failed
        target_bucket = "completed_user_pseudo_ids" if status == "completed" else "failed_user_pseudo_ids"
        nodes[node_id][target_bucket].append(user_id)
        
        # Fill in node_type / node_name from node_context when the
        # running-stage rows did not provide them
        if not nodes[node_id].get("node_type"):
            nodes[node_id]["node_type"] = node_type
        if node_name and not nodes[node_id].get("node_name"):
            nodes[node_id]["node_name"] = node_name

        # Accumulate execution time across users
        nodes[node_id]["execution_time_ms"] += execution_time_ms

        # Count errors and keep up to five unique sample messages
        if status == "failed" and error_message:
            nodes[node_id]["error_count"] += 1
            samples = nodes[node_id]["sample_errors"]
            if error_message not in samples and len(samples) < 5:
                samples.append(error_message)
    
    return nodes
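
# Illustrative example (made-up values): a failed node_context row lands the
# user in the failed bucket and feeds the error statistics:
#
#   _merge_node_context({}, [
#       {"node_id": "n2", "node_type": "webhook", "node_name": "Notify",
#        "user_pseudo_id": "u3", "status": "failed",
#        "execution_time_ms": 120, "error_message": "timeout"},
#   ])
#   # -> {"n2": {..., "failed_user_pseudo_ids": ["u3"],
#   #            "execution_time_ms": 120, "error_count": 1,
#   #            "sample_errors": ["timeout"]}}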


@execute.route('/execute', methods=['POST'])
def execute_workflow():
    """
    Summary view for automation executions.
    Request JSON: { "property_id": "...", "automation_id": "..." }
    Response: executions with per-node user statuses (pending/running/completed/failed).
    """
    try:
        data = request.get_json(silent=True) or {}
        property_id = data.get("property_id")
        automation_id = data.get("automation_id")

        missing = [k for k in ["property_id", "automation_id"] if not data.get(k)]
        if missing:
            return jsonify({"status": "error", "message": f"Missing required fields: {', '.join(missing)}"}), 400

        # Validate input formats
        if not is_valid_property_id(property_id):
            return jsonify({"status": "error", "message": "Invalid property_id format"}), 400

        if not is_valid_uuid(automation_id):
            return jsonify({"status": "error", "message": "Invalid automation_id format"}), 400

        db_conn = DatabaseConnection()

        # 1) Fetch executions for this property/automation
        try:
            # First try simple query to verify data exists
            simple_query = """
                SELECT 
                    property_id,
                    automation_id,
                    CAST(execution_id AS TEXT) AS execution_id,
                    status,
                    datetime_trigger,
                    start_time,
                    end_time,
                    path_to_json_flow
                FROM trigger_manager
                WHERE property_id = :property_id
                  AND automation_id = :automation_id
                ORDER BY datetime_trigger DESC NULLS LAST
            """
            params = {"property_id": property_id, "automation_id": automation_id}
            logger.info(f"Executing simple query for property_id={property_id}, automation_id={automation_id}")
            simple_executions = db_conn.execute_query(simple_query, params, fetch=True)
            logger.info(f"Simple query returned {len(simple_executions) if simple_executions else 0} executions")
            
            if not simple_executions:
                logger.warning(f"No executions found for property_id={property_id}, automation_id={automation_id}")
                return jsonify({"status": "ok", "data": {"property_id": property_id, "automation_id": automation_id, "executions": []}}), 200
            
            # If simple query works, try full query with all fields
            exec_query = """
                SELECT 
                    property_id,
                    automation_id,
                    CAST(execution_id AS TEXT) AS execution_id,
                    status,
                    automation_name,
                    version,
                    -- UTC format for datetime_trigger
                    CASE 
                        WHEN datetime_trigger IS NOT NULL 
                        THEN TO_CHAR(datetime_trigger AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
                        ELSE NULL
                    END AS datetime_trigger,
                    -- Local time for datetime_trigger
                    CASE 
                        WHEN datetime_trigger IS NOT NULL 
                        THEN CAST(datetime_trigger AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Bangkok' AS TEXT)
                        ELSE NULL
                    END AS datetime_trigger_local,
                    -- UTC format for start_time
                    CASE 
                        WHEN start_time IS NOT NULL 
                        THEN TO_CHAR(start_time AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
                        ELSE NULL
                    END AS start_time,
                    -- Local time for start_time
                    CASE 
                        WHEN start_time IS NOT NULL 
                        THEN CAST(start_time AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Bangkok' AS TEXT)
                        ELSE NULL
                    END AS start_time_local,
                    -- UTC format for end_time
                    CASE 
                        WHEN end_time IS NOT NULL 
                        THEN TO_CHAR(end_time AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
                        ELSE NULL
                    END AS end_time,
                    -- Local time for end_time
                    CASE 
                        WHEN end_time IS NOT NULL 
                        THEN CAST(end_time AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Bangkok' AS TEXT)
                        ELSE NULL
                    END AS end_time_local,
                    path_to_json_flow
                FROM trigger_manager
                WHERE property_id = :property_id
                  AND automation_id = :automation_id
                ORDER BY datetime_trigger DESC NULLS LAST
            """
            params = {"property_id": property_id, "automation_id": automation_id}
            logger.debug(f"Executing full query for property_id={property_id}, automation_id={automation_id}")
            executions = db_conn.execute_query(exec_query, params, fetch=True)
            logger.info(f"Full query returned {len(executions) if executions else 0} executions")
            
            # Defensive fallback: if the full query returns no rows even though
            # the simple query did, rebuild the rows from the simple results and
            # format the timestamps in Python instead of SQL
            if not executions and simple_executions:
                logger.warning("Full query returned no rows although the simple query did; rebuilding from simple query results")
                executions = []
                timezone_utc = pytz.UTC
                timezone_bangkok = pytz.timezone('Asia/Bangkok')

                def _format_ts(ts):
                    """Return (utc_str, local_str); naive timestamps are assumed UTC."""
                    if not ts:
                        return None, None
                    if not isinstance(ts, datetime):
                        return str(ts), None
                    if ts.tzinfo is None:
                        ts = timezone_utc.localize(ts)
                    return (
                        ts.strftime('%Y-%m-%dT%H:%M:%SZ'),
                        ts.astimezone(timezone_bangkok).strftime('%Y-%m-%d %H:%M:%S'),
                    )

                for simple_exec in simple_executions:
                    datetime_trigger_utc, datetime_trigger_local = _format_ts(simple_exec.get("datetime_trigger"))
                    start_time_utc, start_time_local = _format_ts(simple_exec.get("start_time"))
                    end_time_utc, end_time_local = _format_ts(simple_exec.get("end_time"))
                    
                    executions.append({
                        "property_id": simple_exec.get("property_id"),
                        "automation_id": simple_exec.get("automation_id"),
                        "execution_id": simple_exec.get("execution_id"),
                        "status": simple_exec.get("status"),
                        "automation_name": None,
                        "version": None,
                        "datetime_trigger": datetime_trigger_utc,
                        "datetime_trigger_local": datetime_trigger_local,
                        "start_time": start_time_utc,
                        "start_time_local": start_time_local,
                        "end_time": end_time_utc,
                        "end_time_local": end_time_local,
                        "timezone": "Asia/Bangkok",
                        "path_to_json_flow": simple_exec.get("path_to_json_flow")
                    })
            
            if executions:
                logger.debug(f"First execution sample: {executions[0]}")
            
            if not executions:
                logger.warning(f"No executions found after all queries for property_id={property_id}, automation_id={automation_id}")
                return jsonify({"status": "ok", "data": {"property_id": property_id, "automation_id": automation_id, "executions": []}}), 200
        except Exception as query_error:
            logger.error(f"Error executing trigger_manager query: {str(query_error)}", exc_info=True)
            logger.error(f"Query was: {exec_query if 'exec_query' in locals() else 'N/A'}")
            return jsonify({"status": "error", "message": f"Database query error: {str(query_error)}"}), 500

        execution_ids_list = [e["execution_id"] for e in executions if e.get("execution_id")]

        # 2) Fetch running stage rows
        try:
            if not execution_ids_list:
                running_stage = []
            else:
                placeholders = [f":exec_id_{i}" for i in range(len(execution_ids_list))]
                stage_query = f"""
                    SELECT 
                        CAST(execution_id AS TEXT) AS execution_id,
                        CAST(user_pseudo_id AS TEXT) AS user_pseudo_id,
                        next_node_id,
                        next_node_type,
                        status
                    FROM user_pseudo_running_stage
                    WHERE property_id = :property_id
                      AND automation_id = :automation_id
                      AND execution_id IN ({','.join(placeholders)})
                """
                params = {
                    "property_id": property_id,
                    "automation_id": automation_id,
                    **{f"exec_id_{i}": exec_id for i, exec_id in enumerate(execution_ids_list)}
                }
                running_stage = db_conn.execute_query(stage_query, params, fetch=True) or []
            logger.debug(f"Found {len(running_stage)} running stage rows")
        except Exception as query_error:
            logger.error(f"Error executing user_pseudo_running_stage query: {str(query_error)}")
            running_stage = []

        # 3) Fetch node_context rows (completed / failed)
        try:
            if not execution_ids_list:
                node_context_rows = []
            else:
                placeholders = [f":exec_id_{i}" for i in range(len(execution_ids_list))]
                ctx_query = f"""
                    SELECT 
                        CAST(execution_id AS TEXT) AS execution_id,
                        CAST(user_pseudo_id AS TEXT) AS user_pseudo_id,
                        node_id,
                        node_type,
                        node_name,
                        status,
                        input_data,
                        output_data,
                        execution_time_ms,
                        error_message
                    FROM node_context
                    WHERE property_id = :property_id
                      AND automation_id = :automation_id
                      AND execution_id IN ({','.join(placeholders)})
                """
                params = {
                    "property_id": property_id,
                    "automation_id": automation_id,
                    **{f"exec_id_{i}": exec_id for i, exec_id in enumerate(execution_ids_list)}
                }
                node_context_rows = db_conn.execute_query(ctx_query, params, fetch=True) or []
            logger.debug(f"Found {len(node_context_rows)} node_context rows")
        except Exception as query_error:
            logger.error(f"Error executing node_context query: {str(query_error)}")
            node_context_rows = []

        # 4) Build response per execution
        execution_map = {e["execution_id"]: e for e in executions if e.get("execution_id")}
        logger.debug(f"Building response for {len(execution_map)} executions")

        stage_by_exec: Dict[str, List[Dict[str, Any]]] = {}
        for row in running_stage:
            stage_by_exec.setdefault(row["execution_id"], []).append(row)

        ctx_by_exec: Dict[str, List[Dict[str, Any]]] = {}
        for row in node_context_rows:
            ctx_by_exec.setdefault(row["execution_id"], []).append(row)

        logger.debug(f"Grouped data: {len(stage_by_exec)} executions with running_stage, {len(ctx_by_exec)} executions with node_context")

        response_execs = []
        for exec_id, exec_row in execution_map.items():
            nodes_bucket = _group_running_stage(stage_by_exec.get(exec_id, []))
            nodes_bucket = _merge_node_context(nodes_bucket, ctx_by_exec.get(exec_id, []))

            # Optional: annotate unknown node types
            for node in nodes_bucket.values():
                ntype = node.get("node_type")
                if ntype and ntype not in typeConditions:
                    node["unknown_type"] = True

            response_execs.append({
                "execution_id": exec_id,
                "status": exec_row.get("status"),
                "automation_name": exec_row.get("automation_name"),
                "version": exec_row.get("version"),
                "datetime_trigger": exec_row.get("datetime_trigger"),
                "datetime_trigger_local": exec_row.get("datetime_trigger_local"),
                "start_time": exec_row.get("start_time"),
                "start_time_local": exec_row.get("start_time_local"),
                "end_time": exec_row.get("end_time"),
                "end_time_local": exec_row.get("end_time_local"),
                "timezone": "Asia/Bangkok",
                "path_to_json_flow": exec_row.get("path_to_json_flow"),
                "nodes": list(nodes_bucket.values())
            })
        
        logger.info(f"Successfully built response with {len(response_execs)} executions")

        return jsonify({
            "status": "ok",
            "data": {
                "property_id": property_id,
                "automation_id": automation_id,
                "executions": response_execs
            }
        }), 200

    except Exception as e:
        logger.error(f"Error in /execute endpoint: {e}", exc_info=True)
        return jsonify({"status": "error", "message": str(e)}), 500
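

# Minimal sketch of mounting this blueprint (the app setup and URL prefix are
# assumptions; neither is defined in this module):
#
#   from flask import Flask
#
#   app = Flask(__name__)
#   app.register_blueprint(execute, url_prefix="/api/automation")
#   # -> POST /api/automation/execute with the JSON body shown above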