import os
import json
import sqlalchemy
from sqlalchemy.exc import DBAPIError
from google.cloud.sql.connector import Connector
from google.oauth2 import service_account
from datetime import datetime
import pandas as pd
import logging
import threading
from typing import Optional, Dict, Any, List
from contextlib import contextmanager
from dotenv import load_dotenv
# Load variables from a local .env file before any os.environ lookups below.
load_dotenv()

# Module-level logger; handlers/level are configured by the application.
logger = logging.getLogger(__name__)


class DatabaseConnection:
    """
    Google Cloud SQL database connection manager with connection pooling.
    Provides thread-safe connection management and query execution.
    """
    
    _instance = None
    _lock = threading.Lock()
    
    def __new__(cls, *args, **kwargs):
        """Singleton pattern to ensure single connection pool instance."""
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(DatabaseConnection, cls).__new__(cls)
        return cls._instance
    
    def __init__(self, 
                 zone: str = None,
                 instance_name: str = None,
                 credential_env_name: str = None,
                 user: str = None,
                 password: str = None,
                 database: str = None,
                 min_connections: int = 1,
                 max_connections: int = 10):
        """
        Initialize Google Cloud SQL connection pool.
        
        Args:
            zone: GCP zone (defaults to env var GCP_ZONE)
            instance_name: Cloud SQL instance name (defaults to env var GCP_INSTANCE_NAME)
            credential_env_name: Environment variable name containing service account JSON
            user: Database user (defaults to env var DB_USER or 'postgres')
            password: Database password (defaults to env var DB_PASSWORD)
            database: Database name (defaults to env var DB_NAME or 'marketing_automation')
            min_connections: Minimum connections in pool
            max_connections: Maximum connections in pool
        """
        # Prevent re-initialization of singleton
        if hasattr(self, '_initialized'):
            return
            
        self.project_id = os.environ.get("GCP_PROJECT")
        self.zone = zone or os.environ.get("POSTGRES_ZONE")
        self.instance_name = instance_name or os.environ.get("INSTANCE_NAME")
        self.user = user or os.getenv('POSTGRES_USER')
        self.password = password or os.getenv('POSTGRES_PASSWORD')
        self.database = database or os.getenv('POSTGRES_DB', 'c360')
        
        if not all([self.project_id, self.zone, self.instance_name]):
            raise ValueError("Missing required GCP configuration: GCP_PROJECT, POSTGRES_ZONE, INSTANCE_NAME")

        if credential_env_name:
            print("HAVE credential_env_name")
            credentials = service_account.Credentials.from_service_account_info(
                json.loads(os.environ.get(credential_env_name))
            )
            self.connector = Connector(credentials=credentials)
        else:
            print("New credential_env_name")
            # Try to use the CUSTOMER_PROFILE_SERVICE_ACCOUNT from your .env
            customer_profile_json = os.environ.get("CUSTOMER_PROFILE_SERVICE_ACCOUNT")
            if customer_profile_json:
                try:
                    credentials = service_account.Credentials.from_service_account_info(
                        json.loads(customer_profile_json)
                    )
                    self.connector = Connector(credentials=credentials)
                    logger.info("Using CUSTOMER_PROFILE_SERVICE_ACCOUNT credentials")
                except Exception as e:
                    logger.warning(f"Failed to use CUSTOMER_PROFILE_SERVICE_ACCOUNT: {e}")
                    self.connector = Connector()
            else:
                self.connector = Connector()

        self.pool = None
        self.min_connections = min_connections
        self.max_connections = max_connections
        self._initialized = True
        
        self._create_connection_pool()

    def _get_connection(self):
        """Get a raw connection from Google Cloud SQL."""
        return self.connector.connect(
            instance_connection_string=f"{self.project_id}:{self.zone}:{self.instance_name}",
            driver="pg8000",
            user=self.user,
            password=self.password,
            db=self.database,
        )

    def _create_connection_pool(self):
        """Create SQLAlchemy engine with connection pooling."""
        try:
            self.pool = sqlalchemy.create_engine(
                "postgresql+pg8000://",
                creator=self._get_connection,
                pool_recycle=1800,
                pool_size=self.max_connections,
                max_overflow=2,
                pool_timeout=30,
                pool_pre_ping=True
            )
            logger.info(f"Google Cloud SQL connection pool created successfully. "
                       f"Pool size: {self.min_connections}-{self.max_connections}")
        except Exception as e:
            logger.error(f"Failed to create Google Cloud SQL connection pool: {e}")
            logger.warning("Database logging will be disabled. Workflow execution will continue without database logging.")
            self.pool = None  # Set to None to indicate database is not available

    @contextmanager
    def get_connection(self):
        """
        Context manager to get a connection from the pool.
        
        Yields:
            sqlalchemy.engine.Connection: Database connection
        """
        if not self.pool:
            raise RuntimeError("Connection pool not initialized")
        
        connection = None
        try:
            connection = self.pool.connect()
            yield connection
        except Exception as e:
            if connection:
                connection.rollback()
            logger.error(f"Database connection error: {e}")
            raise

    @contextmanager
    def get_cursor(self, commit: bool = True):
        """
        Context manager to get a cursor with automatic transaction management.
        Note: SQLAlchemy doesn't use cursors in the same way as psycopg2.
        This method provides compatibility with the existing interface.
        
        Args:
            commit: Whether to commit the transaction automatically
            
        Yields:
            SQLAlchemy connection that can execute queries
        """
        with self.get_connection() as connection:
            trans = connection.begin()
            try:
                yield connection
                if commit:
                    trans.commit()
            except Exception as e:
                trans.rollback()
                logger.error(f"Database transaction error: {e}")
                raise

    def execute_query(self, query: str, params: tuple = None, fetch: bool = False) -> Optional[List[Dict[str, Any]]]:
        """
        Execute a SQL query.
        
        Args:
            query: SQL query string
            params: Query parameters (tuple or dict)
            fetch: Whether to fetch and return results
            
        Returns:
            Query results if fetch=True, None otherwise
        """
        if not self.pool:
            logger.warning("Database not available, skipping query execution")
            return [] if fetch else None
            
        try:
            with self.get_connection() as conn:
                if params:
                    # Convert tuple params to dict for SQLAlchemy compatibility
                    if isinstance(params, tuple):
                        # For positional parameters, convert to named parameters
                        param_dict = {f'param_{i}': param for i, param in enumerate(params)}
                        # Replace %s with :param_0, :param_1, etc.
                        modified_query = query
                        for i in range(len(params)):
                            modified_query = modified_query.replace('%s', f':param_{i}', 1)
                        result = conn.execute(sqlalchemy.text(modified_query), param_dict)
                    else:
                        # Already a dict, use as-is
                        result = conn.execute(sqlalchemy.text(query), params)
                else:
                    result = conn.execute(sqlalchemy.text(query))
                
                conn.commit()
                
                if fetch:
                    # Convert SQLAlchemy result to list of dicts for compatibility
                    rows = result.fetchall()
                    if rows:
                        return [dict(row._mapping) for row in rows]
                    return []
                return None
        except Exception as e:
            logger.error(f"Error executing query: {e}")
            return [] if fetch else None

    def execute_many(self, query: str, params_list: List[tuple]) -> None:
        """
        Execute a SQL query multiple times with different parameters.
        
        Args:
            query: SQL query string
            params_list: List of parameter tuples
        """
        try:
            with self.get_connection() as conn:
                for params in params_list:
                    conn.execute(sqlalchemy.text(query), params)
                conn.commit()
        except Exception as e:
            logger.error(f"Error executing batch query: {e}")
            raise

    def insert_and_return_id(self, query: str, params: tuple = None) -> int:
        """
        Execute an INSERT query and return the generated ID.
        
        Args:
            query: INSERT SQL query string
            params: Query parameters (tuple or dict)
            
        Returns:
            Generated ID from the insert
        """
        if not self.pool:
            logger.warning("Database not available, returning dummy ID")
            return 1  # Return dummy ID when database is not available
            
        try:
            with self.get_connection() as conn:
                if params:
                    # Convert tuple params to dict for SQLAlchemy compatibility
                    if isinstance(params, tuple):
                        # For positional parameters, convert to named parameters
                        param_dict = {f'param_{i}': param for i, param in enumerate(params)}
                        # Replace %s with :param_0, :param_1, etc.
                        modified_query = query
                        for i in range(len(params)):
                            modified_query = modified_query.replace('%s', f':param_{i}', 1)
                        result = conn.execute(sqlalchemy.text(modified_query), param_dict)
                    else:
                        # Already a dict, use as-is
                        result = conn.execute(sqlalchemy.text(query), params)
                else:
                    result = conn.execute(sqlalchemy.text(query))
                
                conn.commit()
                
                # Get the inserted ID
                row = result.fetchone()
                if row:
                    return dict(row._mapping)['id']
                raise Exception("No ID returned from insert query")
        except Exception as e:
            logger.error(f"Error executing insert query: {e}")
            return 1  # Return dummy ID on error

    def insert_row(self, query: str, params: tuple = None) -> int:
        """
        Execute an INSERT query and return the generated ID.
        
        Args:
            query: INSERT SQL query string
            params: Query parameters (tuple or dict)
            
        Returns:
            Generated ID from the insert
        """
        if not self.pool:
            logger.warning("Database not available, returning dummy ID")
            return 1  # Return dummy ID when database is not available
            
        try:
            with self.get_connection() as conn:
                if params:
                    # Convert tuple params to dict for SQLAlchemy compatibility
                    if isinstance(params, tuple):
                        # For positional parameters, convert to named parameters
                        param_dict = {f'param_{i}': param for i, param in enumerate(params)}
                        # Replace %s with :param_0, :param_1, etc.
                        modified_query = query
                        for i in range(len(params)):
                            modified_query = modified_query.replace('%s', f':param_{i}', 1)
                        result = conn.execute(sqlalchemy.text(modified_query), param_dict)
                    else:
                        # Already a dict, use as-is
                        result = conn.execute(sqlalchemy.text(query), params)
                else:
                    result = conn.execute(sqlalchemy.text(query))
                
                conn.commit()
                return True
                # raise Exception("No ID returned from insert query")
        except Exception as e:
            logger.error(f"Error executing insert query: {e}")
            return 1  # Return dummy ID on error

    def insert_many_and_return_ids(self, query: str, values: list[tuple]) -> list[int]:
        """
        Execute a batch INSERT query and return generated IDs.

        Args:
            query: INSERT SQL query string with VALUES (:param, ...)
            values: List of dictionaries or tuples containing values to insert

        Returns:
            List of generated IDs from the inserts
        """
        with self.get_connection() as conn:
            result = conn.execute(query, values)
            return [row[0] for row in result.fetchall()]

    def batch_insert_and_return_ids(self, query: str, values: list[tuple]) -> list[int]:
        """
        Execute a batch INSERT query and return generated IDs.

        Args:
            query: INSERT SQL query string with VALUES %s placeholder
            values: List of tuples containing values to insert

        Returns:
            List of generated IDs from the inserts
        """
        with self.get_connection() as conn:
            result = conn.execute(query, values)
            return [row[0] for row in result.fetchall()]

    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists in the database.
        
        Args:
            table_name: Name of the table to check
            
        Returns:
            True if table exists, False otherwise
        """
        query = """
        SELECT EXISTS (
            SELECT FROM information_schema.tables 
            WHERE table_schema = 'public' 
            AND table_name = %s
        );
        """
        result = self.execute_query(query, (table_name,), fetch=True)
        return result[0]['exists'] if result else False

    def create_schema(self, schema_file_path: str = None) -> None:
        """
        Create database schema from SQL file.
        
        Args:
            schema_file_path: Path to SQL schema file
        """
        if not schema_file_path:
            schema_file_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config', 'database_schema.sql')
        
        try:
            with open(schema_file_path, 'r') as file:
                schema_sql = file.read()
            
            with self.get_connection() as conn:
                conn.execute(sqlalchemy.text(schema_sql))
                conn.commit()
            
            logger.info("Database schema created successfully")
        except Exception as e:
            logger.error(f"Failed to create database schema: {e}")
            raise

    def update(self, query, fetch='all'):
        if not self.pool:
            logger.warning("Database not available, returning empty result")
            if fetch == "one":
                return None
            elif fetch == "all":
                return []
            else:
                return None
        try:
            with self.pool.connect() as conn:
                result = conn.execute(sqlalchemy.text(query))
                conn.commit()
        except Exception as e:
            logger.error(f"Error executing query: {e}")
            if fetch == "one":
                return None
            elif fetch == "all":
                return []
            else:
                return None

    def query(self, query, fetch="all"):
        """
        Legacy method for backward compatibility.
        Run a SQL query.
        
        fetch: 'one', 'all', or 'none'
        """
        if not self.pool:
            logger.warning("Database not available, returning empty result")
            if fetch == "one":
                return None
            elif fetch == "all":
                return []
            else:
                return None

        try:
            with self.pool.connect() as conn:
                result = conn.execute(sqlalchemy.text(query))
                conn.commit()
                if fetch == "one":
                    row = result.fetchone()
                    return dict(row._mapping) if row else None
                elif fetch == "all":
                    rows = result.fetchall()
                    return [dict(row._mapping) for row in rows] if rows else []
                elif fetch == "delete":
                    return result.rowcount  # rows affected
                else:
                    return None
        except Exception as e:
            logger.error(f"Error executing query: {e}")
            if fetch == "one":
                return None
            elif fetch == "all":
                return []
            else:
                return None

    def close_all_connections(self):
        """Close all connections in the pool."""
        self.close()

    def get_pool_status(self) -> Dict[str, int]:
        """
        Get connection pool status information.
        
        Returns:
            Dictionary with pool status information
        """
        if not self.pool:
            return {"error": "Pool not initialized"}
        
        return {
            "min_connections": self.min_connections,
            "max_connections": self.max_connections,
            "engine_status": "active" if self.pool else "inactive"
        }

    def close(self):
        """Close the SQLAlchemy pool and Connector."""
        if self.pool:
            self.pool.dispose()
        self.connector.close()

    def __del__(self):
        """Ensure cleanup when object is destroyed."""
        try:
            self.close()
        except:
            pass  # Ignore errors during cleanup
