"""
Workflow class for managing workflow definitions and node relationships.
"""

from typing import Dict, Any, List, Optional
import json
from .base_node import BaseNode


class WorkflowParsingError(Exception):
    """Exception raised when workflow parsing fails."""
    pass


class WorkflowValidationError(Exception):
    """Exception raised when workflow validation fails."""
    pass


class Workflow:
    """
    Represents a complete workflow with nodes, connections, and configuration.
    
    Manages the workflow definition, node instances, and execution flow.
    Provides methods for workflow validation and node relationship management.
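    
    Example (illustrative sketch; the field names mirror what this class reads
    from the definition, while the node "type" values are assumptions):
    
        definition = {
            "automation_id": "wf_1",
            "automation_name": "Welcome flow",
            "property_id": "prop_123",
            "nodes": {
                "start": {"type": "trigger", "config": {}},
                "send": {"type": "action", "config": {}},
            },
            "connections": {"start": "send"},
        }
        workflow = Workflow("wf_1", definition)
        print(workflow.start_node_id)  # "start"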
    """
    
    def __init__(self, workflow_id: str, definition: Dict[str, Any], database_logger=None):
        """
        Initialize the workflow from a definition dictionary.
        
        Args:
            workflow_id: Unique identifier for the workflow
            definition: Complete workflow definition dictionary
            database_logger: Optional logger used to record automation changes
        """
        self.database_logger = database_logger
        self.workflow_id = workflow_id
        self.definition = definition
        self.name = definition.get('automation_name', f'Workflow {workflow_id}')
        self.version = definition.get('current_version', 1)
        self.property_id = definition.get('property_id')
        
        # Node management
        self.nodes: Dict[str, BaseNode] = {}
        self.connections: Dict[str, List[str]] = {}
        self.parallel_groups: Dict[str, List[str]] = definition.get('parallel_groups', {})
        
        # Workflow settings
        self.settings = definition.get('settings', {})
        self.timezone = self.settings.get('timezone', 'UTC')
        self.error_handling = self.settings.get('error_handling', 'continue')
        
        # Execution state
        self.start_node_id: Optional[str] = None
        
        # Parse connections and nodes from the definition
        self._parse_connections()
        self._parse_nodes()
        
        # Identify start and final nodes automatically
        self._find_start_node()
        self._find_final_node()
    
    def install_automation_context(self, fb, now) -> None:
        """
        Write the automation definition and its metadata to Firebase.
        
        Args:
            fb: Firebase client wrapper exposing a `db.reference()` API
            now: Timestamp to record as creation and last-update time
        """
        base_path = f"account/{self.property_id}/automation_v2/{self.workflow_id}"
        fb.db.reference().child(base_path).set(self.definition)
        fb.db.reference().child(f"{base_path}/edit").set([self.definition])
        fb.db.reference().child(f"{base_path}/display").set(True)
        fb.db.reference().child(f"{base_path}/active").set(False)
        fb.db.reference().child(f"{base_path}/createdate").set(now)
        fb.db.reference().child(f"{base_path}/lastupdate").set(now)
        self.path_to_json_flow = base_path
        if self.database_logger:
            self.database_logger.add_new_automation(self.definition)
    
    def update_automation_context(self, fb, now) -> None:
        """
        Overwrite the automation definition in Firebase while preserving its
        display/active flags, creation date, and edit history.
        
        Args:
            fb: Firebase client wrapper exposing a `db.reference()` API
            now: Timestamp to record as the last-update time
        """
        base_path = f"account/{self.property_id}/automation_v2/{self.workflow_id}"
        display = fb.db.reference().child(f"{base_path}/display").get()
        active = fb.db.reference().child(f"{base_path}/active").get()
        createdate = fb.db.reference().child(f"{base_path}/createdate").get()
        automation_history = fb.db.reference().child(f"{base_path}/edit").get() or []

        fb.db.reference().child(base_path).set(self.definition)
        fb.db.reference().child(f"{base_path}/display").set(display)
        fb.db.reference().child(f"{base_path}/active").set(active)
        fb.db.reference().child(f"{base_path}/createdate").set(createdate)
        fb.db.reference().child(f"{base_path}/lastupdate").set(now)

        # Append the new definition to the edit history
        automation_history.append(self.definition)
        fb.db.reference().child(f"{base_path}/edit").set(automation_history)
    
    def _parse_connections(self) -> None:
        """Parse node connections from the workflow definition."""
        connections_def = self.definition.get('connections', {})
        if not isinstance(connections_def, dict):
            return  # Skip parsing if connections is not a dict
        
        for source_node_id, target_node_ids in connections_def.items():
            if isinstance(target_node_ids, str):
                target_node_ids = [target_node_ids]
            self.connections[source_node_id] = target_node_ids

    def _parse_nodes(self) -> None:
        """Load raw node definitions from the workflow definition.
        
        The stored values are plain definition dictionaries until
        instantiate_nodes() replaces them with BaseNode instances.
        """
        nodes_def = self.definition.get('nodes', {})
        if not isinstance(nodes_def, dict):
            return  # Skip parsing if nodes is not a dict
        
        for node_id, node_def in nodes_def.items():
            self.nodes[node_id] = node_def
    
    def _find_start_node(self) -> None:
        """
        Automatically find the start node by identifying nodes with no incoming connections.
        """
        if 'nodes' not in self.definition:
            return
        
        nodes_def = self.definition['nodes']
        if not isinstance(nodes_def, dict):
            return  # Skip if nodes is not a dict
        
        all_node_ids = set(nodes_def.keys())
        target_node_ids = set()
        
        # Collect all target nodes from connections
        for targets in self.connections.values():
            target_node_ids.update(targets)
        
        # Start nodes are those with no incoming connections
        start_nodes = all_node_ids - target_node_ids
        
        if len(start_nodes) == 1:
            self.start_node_id = list(start_nodes)[0]
        elif len(start_nodes) > 1:
            # If multiple start nodes, look for trigger nodes
            for node_id in start_nodes:
                node_def = nodes_def[node_id]
                if isinstance(node_def, dict) and node_def.get('type') == 'trigger':
                    self.start_node_id = node_id
                    break
            # If no trigger found, use the first one
            if not self.start_node_id:
                self.start_node_id = list(start_nodes)[0]

    def _find_final_node(self) -> List[str]:
        """
        Identify the final node(s): nodes that appear as connection targets
        but never as connection sources (i.e. they have no outgoing connections).
        
        Returns:
            List of final node IDs
        """
        source_nodes = set()
        referenced_nodes = set()

        for src, targets in self.connections.items():
            source_nodes.add(src)
            referenced_nodes.update(targets)

        # Final node(s) are those referenced but never used as a source
        final_nodes = referenced_nodes - source_nodes
        self.final_nodes = final_nodes

        return list(final_nodes)
    
    def add_node(self, node: BaseNode) -> None:
        """
        Add a node to the workflow.
        
        Args:
            node: Node instance to add
        """
        self.nodes[node.node_id] = node
        
        # Set up node connections
        if node.node_id in self.connections:
            for target_node_id in self.connections[node.node_id]:
                node.add_connection(target_node_id)
    
    def get_node(self, node_id: str) -> Optional[BaseNode]:
        """
        Get a node by its ID.
        
        Args:
            node_id: ID of the node to retrieve
            
        Returns:
            Node instance or None if not found
        """
        return self.nodes.get(node_id)
    
    def get_start_node(self) -> Optional[BaseNode]:
        """
        Get the starting node of the workflow.
        
        Returns:
            Starting node instance or None if not set
        """
        if self.start_node_id:
            return self.get_node(self.start_node_id)
        return None
    
    def set_start_node(self, node_id: str) -> None:
        """
        Set the starting node for workflow execution.
        
        Args:
            node_id: ID of the node to set as start node
        """
        if node_id in self.nodes:
            self.start_node_id = node_id
        else:
            raise ValueError(f"Node {node_id} not found in workflow")
    
    def get_connected_nodes(self, node_id: str) -> List[BaseNode]:
        """
        Get all nodes connected to the specified node.
        
        Args:
            node_id: ID of the source node
            
        Returns:
            List of connected node instances
        """
        connected_nodes = []
        if node_id in self.connections:
            for target_node_id in self.connections[node_id]:
                target_node = self.get_node(target_node_id)
                if target_node:
                    connected_nodes.append(target_node)
        return connected_nodes
    
    def get_parallel_group_nodes(self, group_name: str) -> List[BaseNode]:
        """
        Get all nodes in a parallel execution group.
        
        Args:
            group_name: Name of the parallel group
            
        Returns:
            List of nodes in the parallel group
        """
        group_nodes = []
        if group_name in self.parallel_groups:
            for node_id in self.parallel_groups[group_name]:
                node = self.get_node(node_id)
                if node:
                    group_nodes.append(node)
        return group_nodes
    
    def is_parallel_group(self, node_ids: List[str]) -> bool:
        """
        Check if a set of nodes forms a parallel execution group.
        
        Args:
            node_ids: List of node IDs to check
            
        Returns:
            True if nodes form a parallel group, False otherwise
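        
        Example (sketch; assumes the definition contains a parallel group such
        as {"notify": ["email", "sms"]}):
        
            if workflow.is_parallel_group(["sms", "email"]):
                group = workflow.get_parallel_group_nodes("notify")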
        """
        for group_nodes in self.parallel_groups.values():
            if set(node_ids) == set(group_nodes):
                return True
        return False
    
    def validate(self) -> List[str]:
        """
        Validate the workflow configuration.
        
        Returns:
            List of validation error messages (empty if valid)
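        
        Example (sketch; assumes nodes were added via instantiate_nodes()):
        
            errors = workflow.validate()
            if errors:
                raise WorkflowValidationError("; ".join(errors))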
        """
        errors = []
        
        # Check if workflow has nodes
        if not self.nodes:
            errors.append("Workflow has no nodes")
        
        # Check if start node is set
        if not self.start_node_id:
            errors.append("No start node specified")
        elif self.start_node_id not in self.nodes:
            errors.append(f"Start node {self.start_node_id} not found")
        
        # Validate node connections
        for source_node_id, target_node_ids in self.connections.items():
            if source_node_id not in self.nodes:
                errors.append(f"Source node {source_node_id} not found")
            
            for target_node_id in target_node_ids:
                if target_node_id not in self.nodes:
                    errors.append(f"Target node {target_node_id} not found")
        
        # Validate parallel groups
        for group_name, node_ids in self.parallel_groups.items():
            for node_id in node_ids:
                if node_id not in self.nodes:
                    errors.append(f"Parallel group {group_name} references missing node {node_id}")
        
        # Validate individual nodes
        for node in self.nodes.values():
            if not node.validate_config():
                errors.append(f"Node {node.node_id} has invalid configuration")
        
        return errors
    
    def get_node_count(self) -> int:
        """
        Get the total number of nodes in the workflow.
        
        Returns:
            Number of nodes
        """
        return len(self.nodes)
    
    def get_connection_count(self) -> int:
        """
        Get the total number of connections in the workflow.
        
        Returns:
            Number of connections
        """
        return sum(len(targets) for targets in self.connections.values())
    
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert workflow to dictionary representation.
        
        Returns:
            Dictionary representation of the workflow
        """
        return {
            'workflow_id': self.workflow_id,
            'name': self.name,
            'version': self.version,
            'definition': self.definition,
            'node_count': self.get_node_count(),
            'connection_count': self.get_connection_count(),
            'start_node_id': self.start_node_id
        }
    
    def __str__(self) -> str:
        """String representation of the workflow."""
        return f"Workflow(id={self.workflow_id}, name={self.name}, nodes={len(self.nodes)})"
    
    def __repr__(self) -> str:
        """Detailed string representation of the workflow."""
        return f"Workflow(id={self.workflow_id}, name={self.name}, nodes={len(self.nodes)}, connections={len(self.connections)})"
    
    @classmethod
    def from_json(cls, json_data: str) -> 'Workflow':
        """
        Create a workflow from JSON string.
        
        Args:
            json_data: JSON string containing workflow definition
            
        Returns:
            Workflow instance
            
        Raises:
            WorkflowParsingError: If JSON parsing fails
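        
        Example (sketch; the definition must carry a 'workflow_id' key):
        
            workflow = Workflow.from_json('{"workflow_id": "wf_1", "nodes": {}}')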
        """
        try:
            definition = json.loads(json_data)
            return cls.from_dict(definition)
        except json.JSONDecodeError as e:
            raise WorkflowParsingError(f"Invalid JSON format: {str(e)}")
        except Exception as e:
            raise WorkflowParsingError(f"Failed to parse workflow from JSON: {str(e)}")
    
    @classmethod
    def from_dict(cls, definition: Dict[str, Any]) -> 'Workflow':
        """
        Create a workflow from dictionary definition.
        
        Args:
            definition: Dictionary containing workflow definition
            
        Returns:
            Workflow instance
            
        Raises:
            WorkflowParsingError: If definition is invalid
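        
        Example (sketch):
        
            workflow = Workflow.from_dict({"workflow_id": "wf_1", "nodes": {}})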
        """
        if not isinstance(definition, dict):
            raise WorkflowParsingError("Workflow definition must be a dictionary")
        
        if 'workflow_id' not in definition:
            raise WorkflowParsingError("Workflow definition must contain 'workflow_id'")
        
        workflow_id = definition['workflow_id']
        if not isinstance(workflow_id, str) or not workflow_id.strip():
            raise WorkflowParsingError("workflow_id must be a non-empty string")
        
        return cls(workflow_id, definition)
    
    def instantiate_nodes(self, node_factory, database_logger=None) -> None:
        """
        Instantiate all nodes using the provided NodeFactory.
        
        Args:
            node_factory: NodeFactory instance for creating nodes
            database_logger: Optional database logger to pass to nodes
            
        Raises:
            WorkflowParsingError: If node instantiation fails
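        
        Example (sketch; assumes NodeFactory can be constructed without
        arguments):
        
            from ..engine.node_factory import NodeFactory
            workflow.instantiate_nodes(NodeFactory(), database_logger=None)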
        """
        if 'nodes' not in self.definition:
            raise WorkflowParsingError("Workflow definition must contain 'nodes' section")
        
        try:
            # Import here to avoid circular imports
            from ..engine.node_factory import NodeCreationError
            
            created_nodes = node_factory.create_nodes_from_workflow(self.definition, database_logger)
            
            # Add all created nodes to the workflow
            for node_id, node in created_nodes.items():
                self.add_node(node)
                
        except NodeCreationError as e:
            raise WorkflowParsingError(f"Failed to instantiate nodes: {str(e)}")
        except Exception as e:
            raise WorkflowParsingError(f"Unexpected error during node instantiation: {str(e)}")
    
    def to_json(self, indent: Optional[int] = None) -> str:
        """
        Convert workflow to JSON string.
        
        Args:
            indent: JSON indentation level (None for compact format)
            
        Returns:
            JSON string representation of the workflow
        """
        return json.dumps(self.definition, indent=indent)
    
    def validate_structure(self) -> List[str]:
        """
        Validate the workflow JSON structure without instantiating nodes.
        
        Returns:
            List of validation error messages (empty if valid)
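        
        Example (sketch):
        
            for message in workflow.validate_structure():
                print(f"structure error: {message}")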
        """
        errors = []
        
        # Check required top-level fields
        required_fields = ['automation_id']
        for field in required_fields:
            if field not in self.definition:
                errors.append(f"Missing required field: {field}")
        
        # Validate automation_id
        if 'automation_id' in self.definition:
            automation_id = self.definition['automation_id']
            if not isinstance(automation_id, str) or not automation_id.strip():
                errors.append("automation_id must be a non-empty string")
        
        # Validate nodes section if present
        if 'nodes' in self.definition:
            nodes = self.definition['nodes']
            if not isinstance(nodes, dict):
                errors.append("'nodes' must be a dictionary")
            else:
                for node_id, node_def in nodes.items():
                    if not isinstance(node_def, dict):
                        errors.append(f"Node '{node_id}' definition must be a dictionary")
                        continue
                    
                    # Check required node fields
                    required_node_fields = ['type', 'config']
                    for field in required_node_fields:
                        if field not in node_def:
                            errors.append(f"Node '{node_id}' missing required field: {field}")
        
        # Validate connections section if present
        if 'connections' in self.definition:
            connections = self.definition['connections']
            if not isinstance(connections, dict):
                errors.append("'connections' must be a dictionary")
            else:
                node_ids = set(self.definition.get('nodes', {}).keys())
                for source_id, targets in connections.items():
                    if source_id not in node_ids:
                        errors.append(f"Connection source '{source_id}' not found in nodes")
                    
                    if isinstance(targets, str):
                        targets = [targets]
                    elif not isinstance(targets, list):
                        errors.append(f"Connection targets for '{source_id}' must be string or list")
                        continue
                    
                    for target_id in targets:
                        if not isinstance(target_id, str):
                            errors.append(f"Connection target must be string, got {type(target_id)}")
                        elif target_id not in node_ids:
                            errors.append(f"Connection target '{target_id}' not found in nodes")
        
        return errors