import os
from connectors.looker.lookerSDK import LookerSDK
from connectors.firebase.firebase import Firebase
import concurrent.futures
from dotenv import load_dotenv
load_dotenv()
fb = Firebase(host=os.environ.get("FIREBASE_HOST"))

import json
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import List, Dict, Union, Any, Optional

# Define the Pydantic models to match the JSON structure

class Filter(BaseModel):
    """One filter clause of a builder node: a dimension, a condition and its values."""
    # Name of the dimension the clause applies to.
    dimension: str
    # Condition identifier; the set of accepted values is not constrained here.
    condition: str
    # Values the condition is evaluated against (always a list of strings).
    value: List[str]

class Position(BaseModel):
    """Integer x/y coordinates attached to a node or a connection."""
    # NOTE(review): presumably canvas coordinates used by a UI editor — confirm.
    x: int
    y: int

class Node(BaseModel):
    """A single node of an audience definition.

    ``type`` must be ``"builder"`` or ``"audience"``. A builder node must
    carry a ``filters`` list and an audience node must carry a string
    ``audience_id``.

    NOTE(review): ``AudinceBuilder`` also reads a ``view`` key from builder
    nodes; it is not declared here, so it is not validated by this model.
    """
    name: str
    id: str
    type: str
    position: Position
    # validate_default=True is required for the conditional checks below:
    # without it Pydantic skips field validators whenever the key is omitted
    # and the default is used, so a builder node with no 'filters' key (or an
    # audience node with no 'audience_id' key) would be accepted.
    filters: Union[List[Filter], None] = Field(default=None, validate_default=True)
    audience_id: Union[str, None] = Field(default=None, validate_default=True)

    @field_validator('type')
    @classmethod
    def validate_node_type(cls, v):
        """Accept only the two supported node types."""
        if v not in ["builder", "audience"]:
            raise ValueError("Node type must be 'builder' or 'audience'")
        return v

    @field_validator('filters', mode='before')
    @classmethod
    def validate_builder_filters(cls, v, info):
        """Require a list of filters when the node type is 'builder'."""
        # 'type' is declared before 'filters', so it is already present in
        # info.data here (unless its own validation failed).
        if info.data.get('type') == 'builder' and not isinstance(v, list):
            raise ValueError("A 'builder' node must have a 'filters' list")
        return v

    @field_validator('audience_id', mode='before')
    @classmethod
    def validate_audience_id(cls, v, info):
        """Require a string audience id when the node type is 'audience'."""
        if info.data.get('type') == 'audience' and not isinstance(v, str):
            raise ValueError("An 'audience' node must have an 'audience_id'")
        return v

class Connection(BaseModel):
    """A set operation combining two or more sources (node or connection ids)."""
    id: str
    source: List[str]
    operation: str
    position: Position

    @field_validator('source')
    @classmethod
    def validate_source_list(cls, v):
        """Pass the source list through when it names at least two sources."""
        if len(v) >= 2:
            return v
        raise ValueError("Source must contain at least two items")

    @field_validator('operation')
    @classmethod
    def validate_operation_type(cls, v):
        """Pass the operation through when it is one of the supported names."""
        allowed_operations = ["intersect", "union", "full_outer_exclusive", "left_join"]
        if v in allowed_operations:
            return v
        raise ValueError(f"Operation must be one of {allowed_operations}")

class Calculation(BaseModel):
    """Calculation options nested inside ``Settings``; both fields are free-form strings."""
    # NOTE(review): accepted values for 'type' and the format of 'time' are
    # not constrained by this model — confirm against the producer of the JSON.
    type: str
    time: str

class Settings(BaseModel):
    """Audience-level settings: a required type and an optional calculation block."""
    type: str
    # Omitted or null when the audience has no calculation configured.
    calculation: Optional[Calculation] = None

class Audience(BaseModel):
    """Top-level schema of an audience definition as consumed by AudinceBuilder."""
    id: str
    property_id: str
    name: str
    # Nodes keyed by an id string; each value is validated as a Node.
    nodes: Dict[str, Node]
    # Either a single id string, a mapping of connection id -> connection
    # data, or absent/None.
    # NOTE(review): mapping values are typed as Any, so they are not
    # validated against the Connection model — confirm this is intentional.
    connections: Optional[Union[str, Dict[str, Any]]] = None
    settings: Settings

class AudinceBuilder:
    """Resolve an audience definition (nodes + connections) into user ids.

    The definition is a plain dict shaped like the ``Audience`` model. Every
    node yields a list of user pseudo ids (queried through Looker for
    ``builder`` nodes, read from Firebase for ``audience`` nodes) and every
    connection combines its sources with a set operation.
    """

    def __init__(self, json: dict, cache: bool = True):
        """Store the definition, create the Looker client and validate the input.

        Validation is informational only: the outcome is printed by
        ``_validate_json_data`` and an invalid definition does not raise.

        Args:
            json: Audience definition. ``property_id``, ``id``, ``name``,
                ``nodes`` and ``settings`` are required keys; ``connections``
                may be omitted.
            cache: Forwarded to ``LookerSDK.build_audience``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        self.json = json
        self.cache = cache
        self.property_id = json['property_id']
        self.audience_id = json['id']
        self.audience_name = json['name']
        self.nodes = json['nodes']
        # 'connections' is optional in the Audience model, so a definition
        # without the key must not fail with KeyError here.
        self.connections = json.get('connections')
        self.settings = json['settings']
        self.looker = LookerSDK(property_id=self.property_id)

        self._validate_json_data()

    def _validate_json_data(self):
        """Validate ``self.json`` against the ``Audience`` model and print the outcome.

        Returns:
            tuple: ``(True, True)`` when the definition is valid, otherwise
            ``(False, error)`` where ``error`` is the raised ``ValidationError``.
        """
        try:
            Audience(**self.json)
            print("JSON is valid!")
            return True, True
        except ValidationError as e:
            print("JSON is NOT valid.")
            print(e)
            return False, e

    def get_connection(self):
        """Return the connections part of the definition (str, dict or None)."""
        return self.connections

    def get_setting(self):
        """Return the settings part of the definition."""
        return self.settings

    def get_node_audience(self, node_id):
        """Run the Looker query for one node and return the raw query result.

        Args:
            node_id: Key of the node in ``self.nodes``; the node must provide
                ``filters`` and ``view``.
        """
        filters = self.nodes[node_id]['filters']
        view = self.nodes[node_id]['view']
        conditions = self.looker.build_filter_expression([filters])
        data = self.looker.build_audience(view, conditions, self.cache)
        return data

    def get_all_node_audience(self, max_workers: int = 5):
        """Fetch the users of every node concurrently into ``self.audience_data``.

        A node whose fetch raises is recorded with an empty user list and the
        exception is printed, so one failing node does not abort the others.

        Args:
            max_workers: Size of the thread pool used for the fetches.
        """
        audience_data = {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_node = {
                executor.submit(self._fetch_node_data, node_id, node_data): node_id
                for node_id, node_data in self.nodes.items()
            }

            for future in concurrent.futures.as_completed(future_to_node):
                node_id = future_to_node[future]
                try:
                    audience_data[node_id] = future.result()
                except Exception as exc:
                    audience_data[node_id] = {'user_pseudo_ids': [], 'total': 0}
                    print(f'{node_id} generated an exception: {exc}')

        self.audience_data = audience_data

    def _fetch_node_data(self, node_id, node_data):
        """Fetch the users of a single node.

        Args:
            node_id: Key of the node (not used for the lookup itself).
            node_data: Node dict; ``type`` selects the data source.

        Returns:
            dict: ``{"user_pseudo_ids": <users>, "total": <count>}``.

        Raises:
            ValueError: If the node type is neither 'builder' nor 'audience'.
        """
        node_type = node_data['type']
        if node_type == 'builder':
            filters = node_data['filters']
            view = node_data['view']
            conditions = self.looker.build_filter_expression([filters])
            data = self.looker.build_audience(view, conditions, self.cache)
            # An empty result, a missing column and a null column all mean
            # "no users"; none of them should surface as an exception.
            first_row = data[0] if data else {}
            user_list = first_row.get(f'{view}.list_of_user') or []
            print(node_data['id'], len(user_list))
        elif node_type == 'audience':
            audience_id = node_data['audience_id']
            user_pseudo_ids = fb.db.reference().child(f"account/{self.property_id}/audience/{audience_id}/user_pseudo_id").get()
            user_list = user_pseudo_ids if user_pseudo_ids else []
        else:
            # Returning None here would later break get_final_audience with
            # an unrelated TypeError; fail where the cause is visible.
            raise ValueError(f"Unsupported node type: {node_type}")
        return {
            "user_pseudo_ids": user_list,
            "total": len(user_list)
        }

    @staticmethod
    def _combine(operation, source_sets):
        """Apply one connection operation to a non-empty list of resolved sets."""
        first, rest = source_sets[0], source_sets[1:]
        if operation == "intersect":
            return first.intersection(*rest)
        if operation == "union":
            return first.union(*rest)
        if operation == "left_join":
            # Members of the first source that appear in none of the others.
            # difference() also works when there are no other sources.
            return first.difference(*rest)
        if operation == "full_outer_exclusive":
            # Everything except the members common to all sources.
            return first.union(*rest) - first.intersection(*rest)
        raise ValueError(f"Unsupported operation: {operation}")

    def get_final_audience(self):
        """Combine the node audiences according to the connections.

        Node audiences are fetched first if ``get_all_node_audience`` has not
        run yet. ``self.connections`` decides the result:

        * a string: the users stored under that id in ``self.audience_data``;
        * a non-empty dict: the last connection in the dict is taken as the
          root and resolved recursively. Each computed connection is also
          stored in ``self.audience_data`` and the resulting set is kept on
          ``self.final_result``;
        * anything else (None, empty dict): an empty list.

        Returns:
            list: User pseudo ids of the final audience.

        Raises:
            ValueError: On an unknown source id, an unsupported operation, a
                connection without sources, or a circular connection
                reference.
        """
        if not hasattr(self, "audience_data"):
            self.get_all_node_audience()

        connections = self.connections
        resolved = {}        # connection id -> result set, memo for this call
        in_progress = set()  # connection ids on the current recursion path

        def resolve_source(source_id):
            """Return the set of users for either a node or a connection."""
            if source_id in self.nodes:
                return set(self.audience_data[source_id]["user_pseudo_ids"])
            if source_id in connections:
                return compute_connection(source_id)
            raise ValueError(f"Unknown source id: {source_id}")

        def compute_connection(conn_id):
            """Resolve one connection, reusing results computed earlier in this call."""
            if conn_id in resolved:
                return resolved[conn_id]
            if conn_id in in_progress:
                raise ValueError(f"Circular connection reference: {conn_id}")

            in_progress.add(conn_id)
            conn = connections[conn_id]
            operation = conn["operation"]
            sources = conn["source"]

            # Resolve all sources (nodes or other connections).
            source_sets = [resolve_source(src) for src in sources]
            if not source_sets:
                raise ValueError(f"Connection {conn_id} has no sources")

            result = self._combine(operation, source_sets)
            in_progress.discard(conn_id)

            resolved[conn_id] = result
            self.audience_data[conn_id] = {"user_pseudo_ids": list(result), 'total': len(result)}
            return result

        if isinstance(connections, str):
            final_result = self.audience_data[connections]['user_pseudo_ids']
        elif isinstance(connections, dict) and connections:
            # The last connection in the dict is treated as the root.
            last_conn_id = list(connections)[-1]
            final_result = compute_connection(last_conn_id)
            self.final_result = final_result
        else:
            final_result = []

        return list(final_result)