import json
import logging
import os
from datetime import datetime, timedelta

import pandas as pd
import sqlalchemy
from google.cloud.sql.connector import Connector
from google.oauth2 import service_account
from sqlalchemy.exc import DBAPIError


class CloudSQL:
    """Thin wrapper around a Google Cloud SQL (PostgreSQL) instance.

    Owns a Cloud SQL Python Connector and a pooled SQLAlchemy engine.
    Typical usage::

        db = CloudSQL("us-central1", "my-instance")
        db.create_engine(user, password, db_name)
        rows = db.query("SELECT 1")
        db.close()
    """

    def __init__(self, zone, instance_name, credential_env_name=None):
        """Set up the Cloud SQL Connector (no database connection yet).

        Args:
            zone: Cloud SQL region, e.g. "us-central1".
            instance_name: Name of the Cloud SQL instance.
            credential_env_name: Optional name of an environment variable
                holding a service-account key as a JSON string. When omitted,
                the Connector uses Application Default Credentials.

        Raises:
            KeyError: If credential_env_name is given but that variable is
                unset.
        """
        # NOTE(review): GCP_PROJECT may be unset, which would later build an
        # invalid "None:zone:instance" connection string — confirm env setup.
        self.project_id = os.environ.get("GCP_PROJECT")
        self.zone = zone
        self.instance_name = instance_name

        if credential_env_name:
            raw_key = os.environ.get(credential_env_name)
            if raw_key is None:
                # Fail fast with a clear message instead of the opaque
                # TypeError that json.loads(None) would raise.
                raise KeyError(
                    f"Environment variable '{credential_env_name}' is not set"
                )
            credentials = service_account.Credentials.from_service_account_info(
                json.loads(raw_key)
            )
            self.connector = Connector(credentials=credentials)
        else:
            self.connector = Connector()

        # Engine is created lazily by create_engine().
        self.pool = None

    def _get_connection(self, user, password, db_name):
        """Open one pg8000 connection through the Cloud SQL Connector."""
        return self.connector.connect(
            instance_connection_string=f"{self.project_id}:{self.zone}:{self.instance_name}",
            driver="pg8000",
            user=user,
            password=password,
            db=db_name,
        )

    def create_engine(self, user, password, db_name):
        """Create (and store on self.pool) a pooled SQLAlchemy engine.

        Args:
            user: Database user name.
            password: Database password.
            db_name: Database to connect to.

        Returns:
            The created sqlalchemy.engine.Engine.
        """
        self.pool = sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=lambda: self._get_connection(user, password, db_name),
            pool_recycle=1800,   # recycle connections older than 30 minutes
            pool_size=5,
            max_overflow=2,
            pool_timeout=30,
            pool_pre_ping=True,  # detect stale connections before handing them out
        )
        return self.pool

    def query(self, query, fetch="all"):
        """Execute a SQL statement and optionally fetch its results.

        Args:
            query: SQL text to execute.
            fetch: 'one' returns a single row, 'all' returns every row,
                anything else (e.g. 'none') returns None (for DML/DDL).

        Returns:
            The fetched row(s), or None for fetch='none' or on error
            (best-effort API: failures are logged, not raised).

        Raises:
            RuntimeError: If create_engine() has not been called yet.
        """
        if not self.pool:
            raise RuntimeError("Engine not initialized. Call create_engine() first.")

        try:
            with self.pool.connect() as conn:
                result = conn.execute(sqlalchemy.text(query))
                # Fetch before committing so rows are consumed while the
                # result is guaranteed valid.
                rows = None
                if fetch == "one":
                    rows = result.fetchone()
                elif fetch == "all":
                    rows = result.fetchall()
                conn.commit()
                return rows
        except Exception:
            # Was logging.info, which buried real failures; keep the
            # best-effort contract (return None) but log at error level
            # with the traceback.
            logging.exception("Error executing query")
            return None

    def create_partition(self, table: str, unique_dates: set[datetime]):
        """Create daily RANGE partitions of `table` for the given dates.

        Args:
            table: Name of the partitioned parent table. WARNING: the name is
                interpolated directly into DDL (identifiers cannot be bound
                parameters) — pass only trusted, code-controlled names.
            unique_dates: Datetimes whose calendar days need partitions
                (time-of-day is ignored).

        Raises:
            RuntimeError: If create_engine() has not been called yet.
        """
        if not self.pool:
            raise RuntimeError("Engine not initialized. Call create_engine() first.")

        with self.pool.connect() as conn:
            for date in unique_dates:
                partition_name = date.strftime("%Y_%m_%d")
                start_date = date.strftime("%Y-%m-%d 00:00:00")
                # PostgreSQL range-partition upper bounds are EXCLUSIVE.
                # The previous 23:59:59 bound left the final second of each
                # day unassigned; use the next midnight instead.
                end_date = (date + timedelta(days=1)).strftime("%Y-%m-%d 00:00:00")

                sql = f"""
                CREATE TABLE IF NOT EXISTS {table}_{partition_name}
                PARTITION OF {table}
                FOR VALUES FROM ('{start_date}') TO ('{end_date}');
                """
                try:
                    conn.execute(sqlalchemy.text(sql))
                    conn.commit()
                    logging.info(f"Created partition {table}_{partition_name}")
                except Exception:
                    # Keep going: one failing date must not block the rest.
                    logging.exception(f"Error creating partition for {date}")

    def insert_dataframe(self, df: pd.DataFrame, table_name: str, if_exists: str = 'append', index: bool = False):
        """
        Inserts a pandas DataFrame into a Cloud SQL database table.

        Args:
            df (pd.DataFrame): The DataFrame to insert.
            table_name (str): The name of the target database table.
            if_exists (str, optional): How to behave if the table already exists.
                                        'fail': Raise ValueError.
                                        'replace': Drop the table before inserting new data.
                                        'append': Insert new data into the existing table.
                                        Defaults to 'append'.
            index (bool, optional): Whether to write the DataFrame's index as a column.
                                     Defaults to False.

        Raises:
            ValueError: If create_engine() has not been called yet.
            DBAPIError: Re-raised when not caused by an invalidated
                connection and not an OperationalError.
        """
        if self.pool is None:
            raise ValueError("SQLAlchemy engine has not been created. "
                             "Call create_engine() first.")

        def _write(connection):
            # Single definition of the write so the retry path cannot drift
            # from the first attempt.
            df.to_sql(
                name=table_name,
                con=connection,
                if_exists=if_exists,  # bug fix: was hard-coded to 'append'
                index=index,
                method='multi',
                chunksize=1000,
            )

        logging.info(f"Attempting to insert DataFrame into table '{table_name}'...")
        try:
            with self.pool.begin() as connection:
                _write(connection)
            logging.info("DataFrame successfully inserted.")
        except DBAPIError as e:
            if e.connection_invalidated:
                logging.warning("Cloud SQL connection was invalidated, retrying once...")
                # Bug fix: the retry previously reused the stale `connection`
                # from the failed attempt instead of the freshly opened one.
                with self.pool.begin() as retry_conn:
                    _write(retry_conn)
            elif isinstance(e, sqlalchemy.exc.OperationalError):
                # OperationalError subclasses DBAPIError, so a separate
                # `except OperationalError` clause after this one could
                # never fire; handle it here instead.
                logging.error(f"Database connection error: {e}")
                logging.error("Please check your database credentials and connection string.")
            else:
                raise
        except Exception as e:
            logging.error(f"An unexpected error occurred: {e}")

    def close(self):
        """Dispose the SQLAlchemy pool and close the Connector.

        Safe to call repeatedly and on partially-initialized instances
        (e.g. when __init__ raised before self.connector was set).
        """
        if getattr(self, "pool", None) is not None:
            self.pool.dispose()
            self.pool = None
        connector = getattr(self, "connector", None)
        if connector is not None:
            connector.close()

    def __del__(self):
        """Best-effort cleanup at garbage collection; must never raise."""
        try:
            self.close()
        except Exception:
            # Exceptions escaping __del__ are ignored but produce noisy
            # warnings; swallow them deliberately.
            pass
