from abc import ABC, abstractmethod
from datetime import datetime, timedelta
import json
import os
import re
from typing import Dict, Set, Union, List, Optional

from bigquery.bq import BigQuery
from firebase.firebase import Firebase

# Module-wide Firebase client; AudienceReferenceNode reads saved audiences through it.
# The host comes from the FIREBASE_HOST environment variable (None when unset).
fb = Firebase(host=os.getenv("FIREBASE_HOST"))

class QueryBuilder:
    """
    Query builder for constructing BigQuery SQL based on user selections.

    Selections originate from end users, and the client executes a plain SQL
    string (no query parameters). This class therefore defends the SQL text
    itself:
    - every value is rendered as an escaped GoogleSQL string literal;
    - field names, operators, ORDER BY items, LIMIT and the property id are
      validated before being interpolated.

    Invalid input raises ``ValueError``.
    """

    # Comparison operators accepted in ``filters`` / ``user_filters`` (case-insensitive).
    ALLOWED_OPERATORS = ("=", "!=", ">", ">=", "<", "<=", "IN", "NOT IN", "LIKE", "NOT LIKE")
    # User-table columns that ``user_properties`` may add to the SELECT list.
    SELECTABLE_USER_PROPERTIES = ("user_id", "email", "phone", "first_name", "last_name")
    # Relative date operators understood by build_date_range -> window length in days.
    _RELATIVE_WINDOWS = {
        "last_7_days": 7,
        "last_30_days": 30,
        "last_90_days": 90,
        "last_12_months": 365,
    }
    # A bare or dotted column reference such as ``currency`` or ``params.value``.
    _IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*")
    # Characters that must be escaped inside a single-quoted GoogleSQL string literal.
    _LITERAL_ESCAPES = str.maketrans({"\\": "\\\\", "'": "\\'", "\n": "\\n", "\r": "\\r"})

    def __init__(self, property_id: str):
        # The property id becomes part of a backtick-quoted table path. Restricting
        # it to dataset-name characters keeps it from breaking out of the quotes.
        if not re.fullmatch(r"[A-Za-z0-9_]+", str(property_id)):
            raise ValueError(f"Invalid property_id: {property_id!r}")
        self.property_id = property_id
        self.base_table = f"`{os.environ.get('GCP_PROJECT')}.client_{property_id}.event`"
        self.user_table = f"`{os.environ.get('GCP_PROJECT')}.client_{property_id}.user`"

    def build_query(self, selections: Dict) -> str:
        """
        Build a BigQuery SQL query based on user selections.

        Args:
            selections: Dictionary containing user selections
                - event / event_types: event name(s) to match (``event`` wins if both are set)
                - exclude_events: event name(s); users with any such event are excluded
                - date_range: {"operator": ...} using an operator from
                  get_available_filters()["date_operators"]. "custom_range" reads
                  "start"/"end" (YYYY-MM-DD). Bare "start_date"/"end_date" imply custom_range.
                - filters: event-column conditions {"field", "operator", "value"}
                - user_filters: user-table conditions, same shape as filters
                - user_properties: user-table columns to add to the SELECT list
                - include / exclude: optional nested selections combined as
                  ``(include) AND NOT (exclude)``. Top-level conditions are used
                  only when neither produces any condition.
                - order_by: "column [ASC|DESC], ..." or a list of such items
                - limit: Optional integer row limit

        Returns:
            str: Generated SQL query

        Raises:
            ValueError: if a field name, operator, order_by item or limit is invalid.
        """
        select_clause = self._build_select_clause(selections.get('user_properties', []))
        query_parts = [f"SELECT DISTINCT {select_clause}"]
        query_parts.extend(self._build_from_clause(selections))

        where_conditions = self._build_complex_conditions(selections)
        if where_conditions:
            query_parts.append(f"WHERE {' AND '.join(where_conditions)}")

        if selections.get('order_by'):
            query_parts.append(f"ORDER BY {self._validate_order_by(selections['order_by'])}")

        if selections.get('limit'):
            query_parts.append(f"LIMIT {self._coerce_limit(selections['limit'])}")

        return "\n".join(query_parts)

    def _build_from_clause(self, selections: Dict) -> List[str]:
        """Return the FROM line, plus the user-table join when ``u.`` columns are referenced."""
        clauses = [f"FROM {self.base_table}"]
        if self._needs_user_join(selections):
            # The SELECT list and user filters use alias ``u``; without this join
            # the alias would be unbound.
            # NOTE(review): assumes the user table is keyed by user_pseudo_id;
            # confirm against the dataset schema.
            clauses.append(f"LEFT JOIN {self.user_table} AS u USING (user_pseudo_id)")
        return clauses

    def _needs_user_join(self, selections: Dict) -> bool:
        """Return True when the generated SQL will reference the user-table alias ``u``."""
        requested = selections.get('user_properties') or []
        if any(prop in self.SELECTABLE_USER_PROPERTIES for prop in requested):
            return True
        scopes = [selections]
        scopes.extend(
            selections[key] for key in ('include', 'exclude')
            if isinstance(selections.get(key), dict)
        )
        return any(scope.get('user_filters') for scope in scopes)

    def _build_complex_conditions(self, selections: Dict) -> List[str]:
        """Combine ``include``/``exclude`` sub-selections, falling back to the top-level keys."""
        conditions = []
        include = selections.get('include')
        if isinstance(include, dict):
            include_conditions = self._build_where_conditions(include)
            if include_conditions:
                conditions.append('(' + ' AND '.join(include_conditions) + ')')
        exclude = selections.get('exclude')
        if isinstance(exclude, dict):
            exclude_conditions = self._build_where_conditions(exclude)
            if exclude_conditions:
                # Row-level negation: this drops matching event rows, not whole users.
                conditions.append('NOT (' + ' AND '.join(exclude_conditions) + ')')
        # Legacy behaviour: when include/exclude yield nothing, use the top-level keys.
        if not conditions:
            conditions = self._build_where_conditions(selections)
        return conditions

    def _build_select_clause(self, user_properties: Optional[List[str]]) -> str:
        """Build the SELECT list: user_pseudo_id plus any allow-listed user-table columns."""
        select_fields = ["user_pseudo_id"]
        # dict.fromkeys de-duplicates while keeping the caller's order, so a
        # property listed twice does not produce duplicate output columns.
        for prop in dict.fromkeys(user_properties or []):
            if prop in self.SELECTABLE_USER_PROPERTIES:
                select_fields.append(f"u.{prop}")
        return ", ".join(select_fields)

    def _build_where_conditions(self, selections: Dict) -> List[str]:
        """Build the AND-ed WHERE conditions for one selections dict."""
        conditions: List[str] = []

        # Date range: a relative operator, or explicit custom bounds.
        date_conditions = self._build_date_conditions(selections.get('date_range'))
        conditions.extend(date_conditions)

        # Event-name filter. 'event' is the key this method historically read;
        # 'event_types' is the documented key used by the templates.
        event_names = self._as_list(selections.get('event') or selections.get('event_types'))
        if event_names:
            event_conditions = [f"eventName = {self._quote_literal(name)}" for name in event_names]
            conditions.append(f"({' OR '.join(event_conditions)})")

        # User-level exclusion (e.g. "added to cart but did not purchase"). This
        # needs a subquery: a plain row filter cannot remove a user who also has
        # other, non-excluded events.
        excluded = self._as_list(selections.get('exclude_events'))
        if excluded:
            excluded_list = ", ".join(self._quote_literal(name) for name in excluded)
            # IS NOT NULL guards NOT IN, which yields no rows if the subquery returns a NULL.
            sub_conditions = [f"eventName IN ({excluded_list})", "user_pseudo_id IS NOT NULL", *date_conditions]
            conditions.append(
                f"user_pseudo_id NOT IN (SELECT user_pseudo_id FROM {self.base_table} "
                f"WHERE {' AND '.join(sub_conditions)})"
            )

        for filter_condition in selections.get('filters') or []:
            conditions.append(self._build_filter_condition(filter_condition))

        for user_filter in selections.get('user_filters') or []:
            conditions.append(self._build_user_filter_condition(user_filter))

        return conditions

    def _build_date_conditions(self, date_range) -> List[str]:
        """Translate a ``date_range`` selection into eventTimeStamp bounds."""
        if not date_range:
            return []
        if isinstance(date_range, str):
            # Allow the bare operator string as a shorthand.
            date_range = {'operator': date_range}
        start = date_range.get('start') or date_range.get('start_date')
        end = date_range.get('end') or date_range.get('end_date')
        # Explicit bounds without an operator mean a custom range. No operator
        # and no bounds falls through to build_date_range's 30-day default.
        operator = date_range.get('operator') or ('custom_range' if (start or end) else None)
        bounds = self.build_date_range(operator, start, end)
        conditions = []
        # NOTE(review): the end bound is a bare date. If eventTimeStamp is a
        # TIMESTAMP, events later on the end day are excluded; confirm the column type.
        if bounds.get('start_date'):
            conditions.append(f"eventTimeStamp >= {self._quote_literal(bounds['start_date'])}")
        if bounds.get('end_date'):
            conditions.append(f"eventTimeStamp <= {self._quote_literal(bounds['end_date'])}")
        return conditions

    def _build_filter_condition(self, filter_condition: Dict) -> str:
        """Build a filter condition on an event-table column."""
        field = self._validate_identifier(filter_condition.get('field'))
        return self._render_condition(
            field, filter_condition.get('operator', '='), filter_condition.get('value')
        )

    def _build_user_filter_condition(self, user_filter: Dict) -> str:
        """Build a filter condition on a user-table column (alias ``u``)."""
        field = self._validate_identifier(user_filter.get('field'))
        return self._render_condition(
            f"u.{field}", user_filter.get('operator', '='), user_filter.get('value')
        )

    def _render_condition(self, column: str, operator, value) -> str:
        """Render ``column <operator> <value>`` with an allow-listed operator and quoted value(s)."""
        op = str(operator).strip().upper()
        if op not in self.ALLOWED_OPERATORS:
            raise ValueError(f"Unsupported operator: {operator!r}")
        if op in ('IN', 'NOT IN'):
            values = value if isinstance(value, (list, tuple)) else [value]
            # An empty list renders as ('') so the SQL stays syntactically valid.
            items = ", ".join(self._quote_literal(v) for v in values) or "''"
            return f"{column} {op} ({items})"
        if op in ('LIKE', 'NOT LIKE'):
            # Substring match. '%' and '_' inside the value keep their LIKE meaning.
            pattern = f"%{value}%"
            return f"{column} {op} {self._quote_literal(pattern)}"
        return f"{column} {op} {self._quote_literal(value)}"

    @classmethod
    def _quote_literal(cls, value) -> str:
        """Render ``value`` (via ``str``) as a single-quoted, escaped GoogleSQL string literal."""
        return "'" + str(value).translate(cls._LITERAL_ESCAPES) + "'"

    @classmethod
    def _validate_identifier(cls, name) -> str:
        """Return ``name`` if it is a plain or dotted column reference; otherwise raise ValueError."""
        if not isinstance(name, str) or cls._IDENTIFIER_RE.fullmatch(name) is None:
            raise ValueError(f"Invalid field name: {name!r}")
        return name

    def _validate_order_by(self, order_by) -> str:
        """Render ``order_by`` after checking that each item is ``column [ASC|DESC]``."""
        items = order_by if isinstance(order_by, (list, tuple)) else str(order_by).split(',')
        rendered = []
        for item in items:
            tokens = str(item).split()
            valid = (
                1 <= len(tokens) <= 2
                and self._IDENTIFIER_RE.fullmatch(tokens[0]) is not None
                and (len(tokens) == 1 or tokens[1].upper() in ("ASC", "DESC"))
            )
            if not valid:
                raise ValueError(f"Invalid order_by item: {item!r}")
            rendered.append(" ".join(tokens))
        return ", ".join(rendered)

    @staticmethod
    def _coerce_limit(limit) -> int:
        """Return ``limit`` as a non-negative int, rejecting anything else (e.g. SQL fragments)."""
        try:
            value = int(limit)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid limit: {limit!r}") from None
        if value < 0:
            raise ValueError(f"Invalid limit: {limit!r}")
        return value

    @staticmethod
    def _as_list(value) -> list:
        """Normalise None/empty, a single string, or an iterable of strings into a list."""
        if not value:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    def get_available_filters(self) -> Dict:
        """Get available filter options for the query builder (a fresh dict on every call)."""
        return {
            "event_types": [
                "page_view",
                "scroll",
                "click",
                "form_submit",
                "purchase",
                "add_to_cart",
                "remove_from_cart",
                "view_item",
                "begin_checkout",
                "add_payment_info",
                "add_shipping_info",
                "purchase_refund"
            ],
            "user_properties": [
                "user_id",
                "email",
                "phone",
                "first_name",
                "last_name",
                "country",
                "city",
                "device_category",
                "platform"
            ],
            "operators": list(self.ALLOWED_OPERATORS),
            "date_operators": [*self._RELATIVE_WINDOWS, "custom_range"],
        }

    def build_date_range(self, date_operator: Optional[str], custom_start: Optional[str] = None,
                         custom_end: Optional[str] = None) -> Dict:
        """
        Build a date range from a relative operator or custom dates.

        Args:
            date_operator: One of get_available_filters()["date_operators"].
                Any other value (including None) means the last 30 days.
            custom_start / custom_end: Used as-is when date_operator is "custom_range".

        Returns:
            {"start_date": ..., "end_date": ...}. Relative ranges are formatted
            as YYYY-MM-DD and end today.
        """
        if date_operator == "custom_range":
            return {"start_date": custom_start, "end_date": custom_end}
        days = self._RELATIVE_WINDOWS.get(date_operator, 30) if isinstance(date_operator, str) else 30
        # NOTE(review): datetime.now() is local server time, not UTC; confirm this is intended.
        now = datetime.now()
        return {
            "start_date": (now - timedelta(days=days)).strftime("%Y-%m-%d"),
            "end_date": now.strftime("%Y-%m-%d"),
        }

    def validate_selections(self, selections: Dict) -> Dict:
        """
        Validate user selections and return any errors or warnings.

        The checks mirror what build_query accepts: the same keys, defaults and
        field/operator rules.

        Returns:
            Dict with 'valid' boolean and 'errors'/'warnings' lists
        """
        errors: List[str] = []
        warnings: List[str] = []
        available = self.get_available_filters()

        # Check required fields
        if not selections.get('property_id'):
            errors.append("property_id is required")

        # Validate date range
        date_range = selections.get('date_range')
        if isinstance(date_range, dict):
            errors.extend(self._validate_date_bounds(date_range))

        # Validate event types (same key precedence as _build_where_conditions)
        for event_type in self._as_list(selections.get('event') or selections.get('event_types')):
            if event_type not in available['event_types']:
                warnings.append(f"Unknown event type: {event_type}")

        # Validate user properties
        for prop in selections.get('user_properties') or []:
            if prop not in available['user_properties']:
                warnings.append(f"Unknown user property: {prop}")

        # Validate event filters and user filters
        for label, key in (("Filter", 'filters'), ("User filter", 'user_filters')):
            for i, filter_condition in enumerate(selections.get(key) or [], start=1):
                field = filter_condition.get('field')
                if not field:
                    errors.append(f"{label} {i}: field is required")
                elif not (isinstance(field, str) and self._IDENTIFIER_RE.fullmatch(field)):
                    errors.append(f"{label} {i}: invalid field name")
                # A missing operator defaults to '=' exactly as in build_query.
                operator = filter_condition.get('operator', '=')
                if str(operator).strip().upper() not in available['operators']:
                    errors.append(f"{label} {i}: invalid operator")

        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'warnings': warnings
        }

    @staticmethod
    def _validate_date_bounds(date_range: Dict) -> List[str]:
        """Check that any explicit start/end dates parse as YYYY-MM-DD and are ordered."""
        start = date_range.get('start') or date_range.get('start_date')
        end = date_range.get('end') or date_range.get('end_date')
        try:
            parsed_start = datetime.strptime(start, "%Y-%m-%d") if start else None
            parsed_end = datetime.strptime(end, "%Y-%m-%d") if end else None
        except (TypeError, ValueError):
            return ["Invalid date format. Use YYYY-MM-DD"]
        if parsed_start and parsed_end and parsed_start > parsed_end:
            return ["Start date cannot be after end date"]
        return []

    def preview_query(self, selections: Dict) -> Dict:
        """
        Preview the query that would be generated, without executing it.

        Returns:
            Dict with 'query' string and 'estimated_cost' info, or 'error' when
            the selections cannot be turned into SQL.
        """
        try:
            query = self.build_query(selections)
            return {
                'query': query,
                'estimated_cost': 'Query preview generated successfully',
                'valid': True
            }
        except Exception as e:  # API boundary: report the failure instead of raising
            return {
                'query': None,
                'error': str(e),
                'valid': False
            }

    def get_query_templates(self) -> Dict:
        """Get predefined query templates for common use cases."""
        return {
            "recent_purchasers": {
                "name": "Recent Purchasers",
                "description": "Users who made a purchase in the last 30 days",
                "selections": {
                    "event_types": ["purchase"],
                    "date_range": {"operator": "last_30_days"},
                    "user_properties": ["user_id", "email"]
                }
            },
            "cart_abandoners": {
                "name": "Cart Abandoners",
                "description": "Users who added to cart but didn't purchase",
                "selections": {
                    "event_types": ["add_to_cart"],
                    "exclude_events": ["purchase"],
                    "date_range": {"operator": "last_30_days"},
                    "user_properties": ["user_id", "email"]
                }
            },
            "high_value_users": {
                "name": "High Value Users",
                "description": "Users with multiple purchases",
                "selections": {
                    "event_types": ["purchase"],
                    "filters": [
                        {"field": "purchase_count", "operator": ">=", "value": "3"}
                    ],
                    "date_range": {"operator": "last_90_days"},
                    "user_properties": ["user_id", "email", "first_name", "last_name"]
                }
            },
            "new_users": {
                "name": "New Users",
                "description": "Users who registered in the last 7 days",
                "selections": {
                    "event_types": ["first_visit"],
                    "date_range": {"operator": "last_7_days"},
                    "user_properties": ["user_id", "email", "first_name", "last_name"]
                }
            }
        }

    def estimate_audience_size(self, selections: Dict, client) -> Dict:
        """
        Estimate the size of the audience without fetching all user IDs.

        Runs ``SELECT COUNT(DISTINCT user_pseudo_id)`` over the same FROM/WHERE
        the audience query would use, with no ORDER BY, LIMIT or projected user
        columns.

        Args:
            selections: Same shape as for build_query.
            client: Object exposing ``get_query_df(sql)``. Its result is read
                through ``.empty`` and ``.iloc[0, 0]`` (presumably a pandas DataFrame).

        Returns:
            Dict with 'estimated_count', 'confidence' and 'query_used'. On any
            failure, 'estimated_count' is None and 'error' holds the message.
        """
        try:
            count_selections = dict(selections)
            count_selections['limit'] = None  # Remove limit for counting
            count_selections.pop('order_by', None)
            # Projected user columns do not affect the count. Dropping them avoids
            # an unnecessary join when no user filters are present.
            count_selections.pop('user_properties', None)

            # Built from parts rather than by rewriting the SELECT text, which
            # previously left COUNT( unclosed.
            count_parts = ["SELECT COUNT(DISTINCT user_pseudo_id)"]
            count_parts.extend(self._build_from_clause(count_selections))
            where_conditions = self._build_complex_conditions(count_selections)
            if where_conditions:
                count_parts.append(f"WHERE {' AND '.join(where_conditions)}")
            count_query = "\n".join(count_parts)

            result = client.get_query_df(count_query)
            estimated_count = result.iloc[0, 0] if not result.empty else 0

            return {
                'estimated_count': int(estimated_count),
                'confidence': 'high' if estimated_count < 10000 else 'medium',
                'query_used': count_query
            }
        except Exception as e:  # API boundary: report the failure instead of raising
            return {
                'estimated_count': None,
                'error': str(e),
                'confidence': 'low'
            }


class Node(ABC):
    """Abstract audience-graph node that resolves to a set of user_pseudo_id values.

    Subclasses must implement ``evaluate``; the base class only stores the
    node's identifier and display name.
    """

    def __init__(self, node_id: str, name: str):
        self.name = name
        self.id = node_id

    @abstractmethod
    def evaluate(self) -> Set[str]:
        """Return the set of user_pseudo_id values selected by this node."""


class QueryNode(Node):
    """Node whose members are the distinct user_pseudo_id values returned by a fixed SQL query."""

    def __init__(self, node_id: str, name: str, query: str, client=None):
        super().__init__(node_id, name)
        self.query = query
        # Expected to expose get_query_df(sql); its result is indexed like a
        # DataFrame (presumably pandas, to be confirmed).
        self.client = client

    def evaluate(self) -> Set[str]:
        frame = self.client.get_query_df(self.query)
        # A result without a user_pseudo_id column contributes no users.
        if 'user_pseudo_id' in frame:
            members = set(frame['user_pseudo_id'].unique())
        else:
            members = set()

        print(f"QueryNode '{self.name}' fetched {len(members)} user(s)\n")
        return members


class AudienceReferenceNode(Node):
    """Node that reuses the user list of a previously saved audience stored in Firebase."""

    def __init__(self, property_id: str, node_id: str, name: str, audience_id: str):
        super().__init__(node_id, name)
        self.property_id = property_id
        self.audience_id = audience_id

    def evaluate(self) -> Set[str]:
        """Return the stored user_pseudo_id values of the referenced audience.

        Returns an empty set when nothing is stored at the audience's path.
        """
        path = f"account/{self.property_id}/audience/{self.audience_id}/user_pseudo_id"
        stored = fb.db.reference().child(path).get()
        if not stored:
            return set()
        # The Realtime Database returns sparse integer-keyed arrays as lists with
        # None placeholders for missing indices. Those are not user ids.
        # NOTE(review): if this path ever comes back as a dict, iterating it yields
        # its keys (same as before); confirm how audiences are written.
        return {user_id for user_id in stored if user_id is not None}


class DynamicQueryNode(Node):
    """Node whose SQL is generated by QueryBuilder from a selections dict at evaluation time."""

    def __init__(self, node_id: str, name: str, selections: Dict, client=None):
        super().__init__(node_id, name)
        self.selections = selections
        self.client = client
        # Created lazily on the first evaluate() from selections['property_id'].
        self.query_builder = None

    def evaluate(self) -> Set[str]:
        builder = self.query_builder
        if not builder:
            property_id = self.selections.get('property_id')
            if not property_id:
                raise ValueError("property_id is required for DynamicQueryNode")
            builder = QueryBuilder(property_id)
            self.query_builder = builder

        sql = builder.build_query(self.selections)
        frame = self.client.get_query_df(sql)
        # A result without a user_pseudo_id column contributes no users.
        members = set(frame['user_pseudo_id'].unique()) if 'user_pseudo_id' in frame else set()

        print(f"DynamicQueryNode '{self.name}' fetched {len(members)} user(s)")
        print(f"Generated query: {sql}\n")
        return members


class JoinStep:
    """One set operation in an audience graph.

    ``left_node`` and ``right_node`` name either a node id or the id of a step
    executed earlier. ``operation`` selects the set operation applied to them.
    """

    def __init__(self, id: str, left_node: str, right_node: str, operation: str, result_name: str):
        # Parameter names match the keys of each connections["steps"] entry,
        # which is unpacked with JoinStep(**step). That is why `id` keeps its
        # builtin-shadowing name.
        self.id, self.result_name = id, result_name
        self.left, self.right = left_node, right_node
        self.operation = operation


class AudienceBuilder:
    def __init__(self,property_id:str, audience_id:str, data: Dict, client=None):
        self.data = data
        self.property_id = property_id
        self.audience_id = audience_id
        self.nodes_raw = {node["id"]: node for node in data["nodes"]}
        self.steps = [JoinStep(**step) for step in data["connections"].get("steps", [])]
        self.final_step = data["connections"].get("final_step")
        self.results: Dict[str, Set[str]] = {}
        self.node_objects: Dict[str, Node] = {}
        self.node_sizes: Dict[str, int] = {}
        self.step_sizes: Dict[str, int] = {}
        self.client = client

        self._initialize_nodes()

    def _initialize_nodes(self):
        for node_id, node_data in self.nodes_raw.items():
            # print(node_data)
            if node_data["type"] == "query":
                node = QueryNode(node_data["id"], node_data["name"], node_data["query"], self.client)
            elif node_data["type"] == "audience":
                node = AudienceReferenceNode(self.property_id, node_data["id"], node_data["name"], node_data["audience_id"])
            elif node_data["type"] == "dynamic_query":
                node = DynamicQueryNode(node_data["id"], node_data["name"], node_data["selections"], self.client)
            else:
                raise ValueError(f"Unknown node type: {node_data['type']}")
            self.node_objects[node_id] = node

    def _resolve(self, ref: str) -> Set[str]:
        if ref in self.results:
            return self.results[ref]
        if ref in self.node_objects:
            result = self.node_objects[ref].evaluate()
            self.results[ref] = result
            self.node_sizes[ref] = len(result)
            return result
        
        # print(self.node_objects[ref])
        raise ValueError(f"Unknown reference: {ref}")

    def _execute_step(self, step: JoinStep):
        left_set = self._resolve(step.left)
        right_set = self._resolve(step.right)

        if step.operation == "intersect":
            result = left_set & right_set
        elif step.operation == "left_outer":
            result = left_set
        elif step.operation == "left_inner":
            result = left_set - right_set
        elif step.operation == "full_outer":
            result = left_set | right_set
        else:
            raise ValueError(f"Unsupported operation: {step.operation}")

        self.results[step.id] = result
        self.step_sizes[step.id] = len(result)

    def execute(self) -> Dict:
        for step in self.steps:
            self._execute_step(step)

        final_set = self.results.get(self.final_step, set())
        return self._build_output_json(final_set)
    
    def _build_output_json(self, final_result: Set[str]) -> Dict:
        return {
            "audience_id": self.audience_id,
            "nodes": [
                {
                    "id": node_id,
                    "name": self.nodes_raw[node_id]["name"],
                    "type": self.nodes_raw[node_id]["type"],
                    "size": self.node_sizes[node_id]
                }
                for node_id in self.nodes_raw
            ],
            "steps": [
                {
                    "id": step.id,
                    "operation": step.operation,
                    "left": step.left,
                    "right": step.right,
                    "result_name": step.result_name,
                    "size": self.step_sizes.get(step.id, 0)
                }
                for step in self.steps
            ],
            "final_size": len(final_result),
            "user_pseudo_id": list(final_result)
        }


# Example usage and documentation
"""
QUERY BUILDER USAGE EXAMPLES
============================

The QueryBuilder allows you to construct BigQuery SQL queries based on user selections.
Here are some examples of how to use it:

1. Basic Query Builder Usage:
```python
# Initialize query builder
qb = QueryBuilder("your_property_id")

# Define user selections
selections = {
    "property_id": "your_property_id",
    "event_types": ["purchase", "add_to_cart"],
    "date_range": {
        "operator": "custom_range",
        "start": "2024-01-01",
        "end": "2024-01-31"
    },
    "user_properties": ["user_id", "email"],
    "limit": 1000
}

# Build and execute query
query = qb.build_query(selections)
print(query)
```

2. Using Dynamic Query Node:
```python
# Create audience builder with dynamic query node
data = {
    "id": "audience_1",
    "name": "Recent Purchasers",
    "property_id": "your_property_id",
    "nodes": [
        {
            "id": "node_1",
            "name": "Recent Purchasers Query",
            "type": "dynamic_query",
            "selections": {
                "property_id": "your_property_id",
                "event_types": ["purchase"],
                "date_range": {"operator": "last_30_days"},
                "user_properties": ["user_id", "email"]
            }
        }
    ],
    "connections": {
        "steps": [],
        "final_step": "node_1"
    }
}

# Initialize and execute (arguments: property_id, audience_id, data, client)
builder = AudienceBuilder("your_property_id", "audience_1", data, client=BigQuery())
result = builder.execute()
```

3. Advanced Filtering:
```python
selections = {
    "property_id": "your_property_id",
    "event_types": ["purchase"],
    "filters": [
        {
            "field": "value",
            "operator": ">=",
            "value": "100"
        },
        {
            "field": "currency",
            "operator": "=",
            "value": "USD"
        }
    ],
    "user_filters": [
        {
            "field": "country",
            "operator": "=",
            "value": "US"
        }
    ],
    "date_range": {"operator": "last_90_days"}
}
```

4. Query Templates:
```python
qb = QueryBuilder("your_property_id")
templates = qb.get_query_templates()

# Use a template
recent_purchasers = templates["recent_purchasers"]["selections"]
recent_purchasers["property_id"] = "your_property_id"
query = qb.build_query(recent_purchasers)
```

5. Validation and Preview:
```python
# Validate selections before building query
validation = qb.validate_selections(selections)
if validation['valid']:
    # Preview the query
    preview = qb.preview_query(selections)
    print(preview['query'])
else:
    print("Errors:", validation['errors'])
```

6. Estimate Audience Size:
```python
# Estimate size without fetching all user IDs
estimate = qb.estimate_audience_size(selections, client)
print(f"Estimated audience size: {estimate['estimated_count']}")
```

AVAILABLE FILTERS
=================

Event Types:
- page_view, scroll, click, form_submit, purchase, add_to_cart
- remove_from_cart, view_item, begin_checkout, add_payment_info
- add_shipping_info, purchase_refund

User Properties:
- user_id, email, phone, first_name, last_name
- country, city, device_category, platform

Operators:
- =, !=, >, >=, <, <=, IN, NOT IN, LIKE, NOT LIKE

Date Operators:
- last_7_days, last_30_days, last_90_days, last_12_months
- custom_range (also pass "start" and "end" as YYYY-MM-DD in the date_range dict)
"""
