import os
import pytz
import requests
import facebook

from facebook_business.api import FacebookAdsApi
from facebook_business.adobjects.user import User
from facebook_business.adobjects.page import Page
from datetime import datetime, timedelta

class Facebook:
    """Helpers for reading Facebook pages / Messenger conversations and
    syncing recently-active threads into a Firebase Realtime Database.

    The class is used purely as a namespace: every method is a
    ``staticmethod``. Callers invoke ``Facebook.method(...)``; calls through
    an instance also work.
    """

    # Timestamp format of Graph API ``created_time`` / ``updated_time``
    # strings, e.g. "2024-01-01T12:00:00+0000".
    GRAPH_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
    # Seconds to wait on each raw HTTP request. Without a timeout ``requests``
    # can block forever on a stalled connection.
    REQUEST_TIMEOUT = 30

    @staticmethod
    def _parse_fb_time(value):
        """Parse a Graph API timestamp string into a timezone-aware datetime."""
        return datetime.strptime(value, Facebook.GRAPH_TIME_FORMAT)

    @staticmethod
    def get_all_page():
        """Return the pages (accounts) of the user whose token is in ``FB_TOKEN``.

        Initialises the ``FacebookAdsApi`` SDK with the token read from the
        ``FB_TOKEN`` environment variable. It then requests ``me/accounts`` with
        each page's ``id``, ``name`` and ``access_token``, using ``limit=1000``.

        Returns:
            The result of ``User('me').get_accounts(...)``, unchanged.
        """
        access_token = os.getenv('FB_TOKEN')
        FacebookAdsApi.init(access_token=access_token)
        me = User('me')
        fields = ['id', 'name', 'access_token']
        params = {'limit': 1000}
        return me.get_accounts(fields=fields, params=params)

    @staticmethod
    def get_page_obj(page_access_token, page_id):
        """Return a ``Page`` SDK object for ``page_id``.

        Calls ``FacebookAdsApi.init`` with the page token first. This
        presumably replaces the SDK's default session for later calls (confirm).
        """
        FacebookAdsApi.init(access_token=page_access_token)
        return Page(page_id)

    @staticmethod
    def fetch_recent_messages(thread_id, access_token, cutoff_time, max_messages=100):
        """Fetch a conversation's messages that are not older than ``cutoff_time``.

        Pages through ``/{thread_id}/messages`` on Graph API v19.0. Collection
        stops at the first message older than ``cutoff_time``, which assumes
        the API returns messages newest first. It also stops once
        ``max_messages`` have been collected, or when there is no next page.
        Messages without ``created_time`` are skipped.

        Args:
            thread_id: Conversation id.
            access_token: Page access token.
            cutoff_time: Timezone-aware datetime. Older messages are excluded.
            max_messages: Maximum number of messages to return.

        Returns:
            list[dict]: Raw message dicts as returned by the API. This is
            best-effort: on a request, parse or API error a warning is printed
            and the messages gathered so far are returned.
        """
        all_messages = []
        url = f"https://graph.facebook.com/v19.0/{thread_id}/messages"
        params = {
            "access_token": access_token,
            "fields": "message,attachments{file_url},from,created_time",
            "limit": 25
        }

        while url and len(all_messages) < max_messages:
            try:
                res = requests.get(url, params=params, timeout=Facebook.REQUEST_TIMEOUT).json()
                if "error" in res:
                    # Graph API failures come back as an "error" body. They
                    # have no "data"/"paging", so report them instead of
                    # ending silently.
                    print(f"⚠️ Error fetching messages for {thread_id}: {res['error']}")
                    break
                for msg in res.get("data", []):
                    created_str = msg.get("created_time")
                    if not created_str:
                        continue
                    if Facebook._parse_fb_time(created_str) < cutoff_time:
                        return all_messages
                    all_messages.append(msg)
                    if len(all_messages) >= max_messages:
                        # Enforce the cap exactly rather than overshooting by
                        # the remainder of the current page.
                        return all_messages

                # The "next" URL already carries the full query string.
                url = res.get("paging", {}).get("next")
                params = None
            except Exception as e:
                print(f"⚠️ Error fetching messages for {thread_id}: {e}")
                break

        return all_messages

    @staticmethod
    def needs_more_messages(messages, cutoff_time):
        """Return True when a thread's inline messages may not reach back to the cutoff.

        Returns True only if there are at least 10 inline messages and none of
        them (among those with a ``created_time``) is older than
        ``cutoff_time``. In that case the inline batch may be truncated and
        further pages should be fetched.
        """
        # NOTE(review): 10 presumably matches the inline ``messages`` limit
        # requested via FIELDS; confirm against the caller's configuration.
        if len(messages) < 10:
            return False
        return not any(
            Facebook._parse_fb_time(m["created_time"]) < cutoff_time
            for m in messages
            if m.get("created_time")
        )

    @staticmethod
    def _collect_recent_threads(response, facebook_page_id, cutoff_time):
        """Walk conversation pages and map each non-page participant id to its thread.

        Stops at the first conversation updated before ``cutoff_time``. This
        assumes conversations arrive most-recently-updated first. A later
        thread for the same user id overwrites an earlier one.
        """
        threads_by_user_id = {}
        while response:
            for thread in response.get("data", []):
                updated_str = thread.get("updated_time")
                if not updated_str:
                    continue
                if Facebook._parse_fb_time(updated_str) < cutoff_time:
                    return threads_by_user_id

                participants = thread.get("participants", {}).get("data", [])
                user_participant = next(
                    (p for p in participants if p.get("id") not in (None, facebook_page_id)),
                    None,
                )
                if user_participant:
                    threads_by_user_id[user_participant["id"]] = thread

            next_url = response.get("paging", {}).get("next")
            if not next_url:
                break
            try:
                response = requests.get(next_url, timeout=Facebook.REQUEST_TIMEOUT).json()
            except Exception as e:
                # Best-effort: keep the threads gathered so far, but say why we stopped.
                print(f"⚠️ Error paginating conversations for {facebook_page_id}: {e}")
                break
        return threads_by_user_id

    @staticmethod
    def _attach_recent_messages(threads_by_user_id, page_access_token, cutoff_time):
        """Replace each thread's ``messages`` with its de-duplicated recent messages (in place)."""
        for thread in threads_by_user_id.values():
            thread_id = thread["id"]
            inline_messages = thread.get("messages", {}).get("data", [])

            if Facebook.needs_more_messages(inline_messages, cutoff_time):
                more_messages = Facebook.fetch_recent_messages(
                    thread_id, page_access_token, cutoff_time, 100)
                all_messages = inline_messages + more_messages
            else:
                all_messages = [
                    m for m in inline_messages
                    if "created_time" in m
                    and Facebook._parse_fb_time(m["created_time"]) >= cutoff_time
                ]

            # De-duplicate by message id. Inline and fetched pages overlap;
            # later copies win, and messages without an id are dropped.
            message_map = {m["id"]: m for m in all_messages if m.get("id")}
            thread["messages"] = {"data": list(message_map.values())}

    @staticmethod
    def _sync_threads_to_firebase(fb, property_id, facebook_page_id, threads_by_user_id):
        """Merge each thread into ``account/{property_id}/chat/facebook/{page_id}/{user_id}``."""
        base_ref = fb.db.reference(f'account/{property_id}/chat/facebook/{facebook_page_id}')

        for user_id, new_thread in threads_by_user_id.items():
            doc_ref = base_ref.child(user_id)
            existing_data = doc_ref.get()

            if not existing_data:
                doc_ref.set(new_thread)
                continue

            # Keep stored history and append only messages with unseen ids.
            existing_messages = existing_data.get("messages", {}).get("data", [])
            existing_ids = {m.get("id") for m in existing_messages}
            new_msgs = new_thread.get("messages", {}).get("data", [])
            combined_msgs = existing_messages + [
                m for m in new_msgs if m.get("id") not in existing_ids
            ]
            new_thread["messages"] = {"data": combined_msgs}

            # Write only when something actually changed.
            if (new_thread.get("updated_time") != existing_data.get("updated_time")
                    or len(combined_msgs) > len(existing_messages)):
                doc_ref.set(new_thread)

    @staticmethod
    def update_msg_summary(property_facebook_pairs, FIELDS, LIMIT, fb):
        """Sync conversations active in the last 2 hours for each page into Firebase.

        For every pair, the method:
        1. lists the page's conversations,
        2. keeps those updated within the last 2 hours, keyed by the non-page
           participant's id,
        3. fills each thread with its recent messages, fetching extra pages
           when the inline batch may be truncated,
        4. merges the result into Firebase.

        Args:
            property_facebook_pairs: Iterable of dicts with keys
                ``property_id``, ``facebook_page_id`` and ``access_token``.
            FIELDS: ``fields`` argument for the ``/{page_id}/conversations``
                request. Presumably it must include ``updated_time``,
                ``participants`` and ``messages``; confirm with the caller.
            LIMIT: ``limit`` (page size) for the conversations request.
            fb: Object exposing ``db.reference(path)``. This is presumably the
                ``firebase_admin`` module; verify against the caller.

        Returns:
            True once all pairs have been processed.

        Raises:
            Whatever ``graph.request`` or the Firebase client raise. These
            propagate and abort the remaining pairs.
        """
        # The zone only affects representation: comparisons between aware
        # datetimes are absolute.
        bangkok_tz = pytz.timezone('Asia/Bangkok')

        for facebook_pair in property_facebook_pairs:
            property_id = facebook_pair['property_id']
            facebook_page_id = facebook_pair['facebook_page_id']
            page_access_token = facebook_pair['access_token']

            graph = facebook.GraphAPI(access_token=page_access_token)
            response = graph.request(f"{facebook_page_id}/conversations", args={
                "fields": FIELDS,
                "limit": LIMIT
            })

            cutoff_time = datetime.now(bangkok_tz) - timedelta(hours=2)

            threads_by_user_id = Facebook._collect_recent_threads(
                response, facebook_page_id, cutoff_time)
            Facebook._attach_recent_messages(
                threads_by_user_id, page_access_token, cutoff_time)
            Facebook._sync_threads_to_firebase(
                fb, property_id, facebook_page_id, threads_by_user_id)
        return True