import json
import os
import hashlib
import logging
import re
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List

import pytz
import requests

from firebase.firebase import Firebase
from feature.automation.line import *

logger = logging.getLogger(__name__)

fb = Firebase(host=os.environ.get("FIREBASE_HOST"))
timezone = pytz.timezone('Asia/Bangkok')

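# This module implements a small node-based marketing-automation engine:
# TriggerNode decides when a flow should run, AudienceNode resolves the
# recipient list, ABNode splits it, WaitNode pauses execution, DestinationNode
# delivers messages over LINE/Facebook, and FlowExecutor walks the node graph.
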
class Node:
    def __init__(self, node_id, name):
        self.node_id = node_id
        self.name = name

    def execute(self, context):
        """
        Execute the node with the given context.
        
        Args:
            context: Dictionary containing execution context and data
            
        Returns:
            bool: True if execution should continue, False if execution should stop
        """
        raise NotImplementedError("Execute method must be implemented in subclass.")
    
    def add_to_context(self, context, key, value):
        """
        Safely add a key-value pair to the context.
        
        Args:
            context: The context dictionary
            key: The key to add
            value: The value to add
        """
        context[key] = value
    
    def get_from_context(self, context, key, default=None):
        """
        Safely get a value from the context.
        
        Args:
            context: The context dictionary
            key: The key to retrieve
            default: Default value if key doesn't exist
            
        Returns:
            The value from context or default
        """
        return context.get(key, default)
    
    def log_execution(self, context, **kwargs):
        """
        Add a standardized execution log entry to the context.
        
        Args:
            context: The context dictionary
            **kwargs: Additional fields to include in the log entry
        """
        log_entry = {
            "node_id": self.node_id,
            "name": self.name,
            "timestamp": datetime.now().isoformat(),
            **kwargs
        }
        
        # Add split target if available
        split_target = context.get('current_split_target')
        if split_target:
            log_entry["split_target"] = split_target
            
        context['execution_log'].append(log_entry)

class TriggerNode(Node):
    def __init__(self, node_id, name, trigger, cron=None, interval: list = None, last_triggered=None,
                 now: datetime = None, datebegin: str = None, dateend: str = None, alwaywhen: str = 'off'):
        super().__init__(node_id, name)
        self.now = now
        self.trigger = trigger
        self.cron = cron
        self.interval = interval
        self.last_triggered = last_triggered

        def _parse_local(value: str) -> datetime:
            # Accept both "%Y-%m-%d %H:%M:%S" and date-only "%Y-%m-%d" strings,
            # parsing each bound independently.
            try:
                parsed = datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                parsed = datetime.strptime(value, "%Y-%m-%d")
            return timezone.localize(parsed)

        self.datebegin = _parse_local(datebegin)
        self.dateend = _parse_local(dateend)
        self.alwaywhen = alwaywhen
        self.last_triggered_datetime = timezone.localize(
            datetime.strptime(self.last_triggered, "%Y-%m-%d %H:%M:%S").replace(second=0, microsecond=0)
        )
        logging.info(f"now={self.now} last_triggered={self.last_triggered_datetime}")
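
    # Each entry in self.interval is a dict keyed by "field", e.g.
    #   {"field": "days", "daysInterval": 2, "triggerAtHour": 9, "triggerAtMinute": 30}
    # Supported fields: "minutes", "hours", "days", "weeks", "months", "specific".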
    
    def should_trigger_now(self):
        if self.alwaywhen == 'off':
            # Bounded schedule: only fire inside [datebegin, dateend].
            if not self.datebegin <= self.now <= self.dateend:
                logging.info('Not in trigger date range')
                return False
        else:
            # Open-ended schedule: only require that the start time has passed.
            if self.now < self.datebegin:
                logging.info('Before datebegin')
                return False

        for interval in self.interval:
            field = interval.get("field")

            trigger_hour = interval.get("triggerAtHour", 0)
            trigger_minute = interval.get("triggerAtMinute", 0)

            if field == "minutes":
                minutes_interval = interval.get("minutesInterval", None)
                if minutes_interval is not None:
                    if self.now.minute % minutes_interval == 0:
                        return True

            elif field == "hours":
                hours_interval = interval.get("hoursInterval", None)
                if hours_interval is not None:
                    if self.now.hour % hours_interval == 0 and self.now.minute == trigger_minute:
                        return True

            elif field == "days":
                days_interval = interval.get("daysInterval", None)
                if days_interval is not None:
                    reference_date = self.last_triggered_datetime
                    delta_days = (self.now.date() - reference_date.date()).days
                    if delta_days % days_interval == 0 and self.now.hour == trigger_hour and self.now.minute == trigger_minute:
                        return True

            elif field == "weeks":
                weeks_interval = interval.get("weeksInterval", None)
                trigger_days = interval.get("triggerAtDay", [])
                if weeks_interval is not None:
                    reference_date = self.last_triggered_datetime
                    delta_weeks = ((self.now - reference_date).days) // 7
                    if delta_weeks % weeks_interval == 0 and self.now.weekday() in trigger_days and self.now.hour == trigger_hour and self.now.minute == trigger_minute:
                        return True

            elif field == "months":
                months_interval = interval.get("monthsInterval", None)
                trigger_day_of_month = interval.get("triggerAtDayOfMonth", [1])
                if months_interval is not None:
                    reference_date = self.last_triggered_datetime
                    months_since = (self.now.year - reference_date.year) * 12 + (self.now.month - reference_date.month)
                    if months_since % months_interval == 0 and self.now.day in trigger_day_of_month and self.now.hour == trigger_hour and self.now.minute == trigger_minute:
                        return True
            
            elif field == "specific":
                for dt_str in interval.get("specific", []):
                    logging.info(dt_str)
                    try:
                        trigger_time = timezone.localize(datetime.fromisoformat(dt_str))
                        logging.info(f"{trigger_time} vs now={self.now}")
                        if trigger_time.replace(second=0, microsecond=0) == self.now.replace(second=0, microsecond=0):
                            return True
                    except ValueError:
                        continue

        return False
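
    # Example: with interval [{"field": "minutes", "minutesInterval": 15}],
    # should_trigger_now() is True whenever self.now.minute is 0, 15, 30 or 45.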
    
    def generate_schedule(self):
        """Generate list of times this node should run"""
        if self.alwaywhen == "on":
            return []  # always trigger, no need for specific schedule

        schedule = []
        current = self.datebegin

        while current <= self.dateend:
            for interval in self.interval:
                field = interval.get("field")
                trigger_hour = interval.get("triggerAtHour", 0)
                trigger_minute = interval.get("triggerAtMinute", 0)

                if field == "minutes":
                    minutes_interval = interval.get("minutesInterval", None)
                    if minutes_interval is not None:
                        current = self.datebegin.replace(second=0, microsecond=0)
                        while current <= self.dateend:
                            schedule.append(current)
                            current += timedelta(minutes=minutes_interval)

                elif field == "hours":
                    hours_interval = interval.get("hoursInterval", None)
                    if hours_interval is not None:
                        current = self.datebegin.replace(minute=trigger_minute, second=0, microsecond=0)
                        while current <= self.dateend:
                            schedule.append(current)
                            current += timedelta(hours=hours_interval)

                elif field == "days":
                    days_interval = interval.get("daysInterval", None)
                    if days_interval is not None:
                        current = self.datebegin.replace(hour=trigger_hour, minute=trigger_minute, second=0, microsecond=0)
                        if current < self.datebegin:
                            current = self.datebegin
                        while current <= self.dateend:
                            schedule.append(current)
                            current += timedelta(days=days_interval)
                        if not schedule and self.datebegin.date() == self.dateend.date():
                            schedule.append(
                                self.datebegin.replace(
                                    hour=trigger_hour, minute=trigger_minute, second=0, microsecond=0
                                )
                            )

                elif field == "weeks":
                    weeks_interval = interval.get("weeksInterval", None)
                    trigger_days = interval.get("triggerAtDay", [])
                    if weeks_interval is not None:
                        current = self.datebegin.replace(hour=trigger_hour, minute=trigger_minute, second=0, microsecond=0)
                        while current <= self.dateend:
                            if current.weekday() in trigger_days:
                                schedule.append(current)
                            current += timedelta(days=1)

                elif field == "months":
                    months_interval = interval.get("monthsInterval", None)
                    trigger_day_of_month = interval.get("triggerAtDayOfMonth", [1])
                    if months_interval is not None:
                        reference_month = self.last_triggered_datetime.month
                        months_since = (current.year - self.last_triggered_datetime.year) * 12 + (current.month - reference_month)
                        if months_since % months_interval == 0 and current.day in trigger_day_of_month and current.hour == trigger_hour and current.minute == trigger_minute:
                            schedule.append(current)

                elif field == "specific":
                    for dt_str in interval.get("specific", []):
                        try:
                            trigger_time = timezone.localize(datetime.fromisoformat(dt_str))
                            if trigger_time.date() >= self.datebegin.date() and trigger_time.date() <= self.dateend.date():
                                schedule.append(trigger_time.replace(second=0, microsecond=0))
                        except ValueError:
                            continue

            current += timedelta(minutes=1)  # iterate minute by minute

        run_times = sorted({dt.strftime("%Y-%m-%d %H:%M:%S") for dt in schedule})
        run_days = sorted({dt.replace(hour=0, minute=0, second=0, microsecond=0) for dt in schedule})
        return run_times, run_days
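
    # generate_schedule() returns (run_times, run_days): run_times are the exact
    # trigger moments as "%Y-%m-%d %H:%M:%S" strings; run_days are the distinct,
    # midnight-normalized days on which at least one run occurs.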

    def execute(self, context):
        if self.trigger == 'interval' and self.interval:
            if self.should_trigger_now():
                self.log_execution(context, type="trigger", triggered=True)
                logging.info("Trigger fired!")
                return True
            else:
                logging.info("Trigger not due yet.")
                self.log_execution(context, type="trigger", triggered=False)
                return False
        elif self.trigger == 'manual':
            self.log_execution(context, type="trigger", triggered=True)
            logging.info("Manual trigger fired!")
            return True
        # Unknown trigger type: do not start the flow.
        self.log_execution(context, type="trigger", triggered=False)
        return False

class AudienceNode(Node):
    def __init__(self, property_id, node_id, name, audience_id):
        super().__init__(node_id, name)
        self.audience_id = audience_id
        self.property_id = property_id
    
    def execute(self, context):
        # Resolve the audience definition from Firebase and rebuild its member list.
        audience_context = fb.db.reference().child(f'account/{self.property_id}/audience/{self.audience_id}').get()
        if not audience_context:
            logging.error(f"Audience {self.audience_id} not found")
            return False

        if 'nodes' not in audience_context:
            # No builder graph: fall back to the precomputed member list, if any.
            final_result = audience_context.get('user_pseudo_id') or []
        else:
            from feature.audience.audience_builder import AudinceBuilder
            builder = AudinceBuilder(json=audience_context, cache=True)
            # get_final_audience() is only meaningful when the builder has connections.
            final_result = builder.get_final_audience() if builder.connections else []
        logging.info(f"Filtering audience by ID: {self.audience_id} | {len(final_result)}")

        # Add audience data to context
        self.add_to_context(context, 'audience', f'audience-{self.audience_id}')
        self.add_to_context(context, 'audience_size', len(final_result))
        self.add_to_context(context, 'audience_list', final_result)

        self.log_execution(context, type="audience", audience_id=self.audience_id,
                           user_pseudo_ids=final_result, total_user_pseudo_ids=len(final_result))
        return True
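
# DestinationNode consumes 'audience_list' from the context, maps each
# user_pseudo_id to channel-specific social ids, and pushes the configured
# content through LINE or the Facebook Graph API.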

class DestinationNode(Node):
    def __init__(self, property_id, node_id, name, channel_id: List, channel_type, content):
        super().__init__(node_id, name)
        self.property_id = property_id
        self.channel_id = channel_id
        self.channel_type = channel_type
        self.content = content
        self.content_id = content['content_id']
        self.content_name = content['name']
    
    def findSocialId(self, user_pseudo_id):
        social_id = fb.db.reference().child(f"account/{self.property_id}/profile/{user_pseudo_id}/{self.channel_type}").get()
        if social_id:
            grouped = defaultdict(set)

            for user in social_id.values():
                page_id = user['source']['page_id']
                user_id = user['id']
                grouped[page_id].add(user_id)

            # Convert sets to lists
            grouped_final = {k: list(v) for k, v in grouped.items()}
            return grouped_final
        return None
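
    # findSocialId() returns {page_id: [social user ids]} for this channel type,
    # or None when the profile has no linked account.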
    
    def _get_page_access_token(self, page_id: str) -> str:
        url = "https://graph.facebook.com/v24.0/me/accounts"
        params = {"access_token": os.environ.get('FB_TOKEN')}
        response = requests.get(url, params=params)
        data = response.json()
        for page in data.get('data', []):
            if str(page_id) == str(page['id']):
                return page['access_token']
        return None
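
    # The page token is resolved via the account behind FB_TOKEN; pages that
    # account does not manage yield None.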
    
    def getContentContext(self):
        contentContext = fb.db.reference().child(f"account/{self.property_id}/content/{self.channel_type}/{self.content_id}").get()
        if not contentContext:
            return None
        
        if self.channel_type == 'line':
            return contentContext['json']
        elif self.channel_type == 'facebook':
            if 'json' not in contentContext:
                raise ValueError(f"No JSON object in Content {self.content_id}!")
            content_json = contentContext['json'][0]
            if 'data' not in content_json:
                raise ValueError(f"No data object in json content {self.content_id}!")
            
            content_json_data = content_json['data']
            if 'message' not in content_json_data:
                raise ValueError(f"No message object in json content data {self.content_id}!")
            
            return content_json_data['message']
        
        return False
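
    # getContentContext() returns the LINE message JSON, the Facebook "message"
    # payload, None when the content record is missing, or False for unsupported
    # channel types.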
    
    def execute(self, context):
        if self.channel_type == "line":
            audience_list = self.get_from_context(context, 'audience_list', [])
            split_target = self.get_from_context(context, 'current_split_target')
            
            if split_target:
                logging.info(f"Sending content '{self.content['name']}' to LINE OA Channel: {self.channel_id} (Split: {split_target})")
            else:
                logging.info(f"Sending content '{self.content['name']}' to LINE OA Channel: {self.channel_id}")
            
            if not audience_list:
                logging.info("No audience found for this destination node.")
                return False

            userContextList = []

            mappingUserSocial = {}
            
            for user in audience_list:
                social_id = self.findSocialId(user)
                userContextList.append(social_id)
                if social_id:
                    for _, user_ids in social_id.items():
                        for uid in user_ids:
                            mappingUserSocial[uid] = user
            
            
            merged = defaultdict(set)
            for d in userContextList:
                if d is not None:
                    for page_id, user_ids in d.items():
                        merged[page_id].update(user_ids)

            # Convert sets to lists
            merged_final = {k: list(v) for k, v in merged.items()}
            logging.info(f"Found {len(merged_final)} channels with users to message")
            
            #Get content
            contentContext = self.getContentContext()
            
            if not contentContext:
                logging.info("No content found for this destination")
                return False
            
            #Loop by channel and send message
            sent = 0
            error = 0
            user_pseudo_ids_success = []
            for ch in self.channel_id:
                if ch in merged_final:
                    userList = merged_final[ch]
                    logging.info(f"Sending to channel {ch}: {len(userList)} users")
                    bot = MessageClient(access_token=os.environ.get(f"LINE_{ch}"))
                    # Loop by user
                    for user_id in userList:
                        body = CraftMessage.push_message(user_id, contentContext)
                        response = bot.push(body)
                        if response.status_code == 200:
                            sent += 1
                            if user_id in mappingUserSocial:
                                user_pseudo_ids_success.append(mappingUserSocial[user_id])
                        else:
                            error += 1
                            logging.error(f"LINE MESSAGE failed with status {response.status_code}")

                else:
                    logging.info(f"Channel {ch} not found in merged results")
            
            user_pseudo_ids_fail = list(set(audience_list) - set(user_pseudo_ids_success))
            
            # Log execution with split context if available
            log_data = {
                "type": "destination",
                "channel_type": self.channel_type,
                "channel_id_count": len(self.channel_id),
                "total_user_pseudo_ids": len(audience_list),
                "user_pseudo_ids": audience_list,
                "user_pseudo_ids_success": user_pseudo_ids_success,
                "total_user_pseudo_ids_success": len(user_pseudo_ids_success),
                "user_pseudo_ids_fail": user_pseudo_ids_fail,
                "total_user_pseudo_ids_fail": len(user_pseudo_ids_fail),
            }
            
            self.log_execution(context, **log_data)
            
            logging.info(f"Successfully sent messages to {sent} users")
        elif self.channel_type == 'facebook':
            audience_list = self.get_from_context(context, 'audience_list', [])
            split_target = self.get_from_context(context, 'current_split_target')
            
            if split_target:
                logging.info(f"Sending content '{self.content['name']}' to Facebook Page Channel: {self.channel_id} (Split: {split_target})")
            else:
                logging.info(f"Sending content '{self.content['name']}' to Facebook Page Channel: {self.channel_id}")
            
            if not audience_list:
                logging.info("No audience found for this destination node.")
                return False
            userContextList = []

            mappingUserSocial = {}
            
            for user in audience_list:
                social_id = self.findSocialId(user)
                userContextList.append(social_id)
                if social_id:
                    for _, user_ids in social_id.items():
                        for uid in user_ids:
                            mappingUserSocial[uid] = user

            merged = defaultdict(set)
            for d in userContextList:
                if d is not None:
                    for page_id, user_ids in d.items():
                        merged[page_id].update(user_ids)

            # Convert sets to lists
            merged_final = {k: list(v) for k, v in merged.items()}
            logging.info(f"merged_final: {merged_final}")
            logging.info(f"FACEBOOK: Found {len(merged_final)} channels with users to message")
            
            #Get content
            contentContext = self.getContentContext()
            
            if not contentContext:
                logging.info("No content found for this destination")
                return False
            else:
                logging.info(f"contentContext: {contentContext}")
            
            #Loop by channel and send message
            sent = 0
            error = 0
            user_pseudo_ids_success = []
            for ch in self.channel_id:
                logging.info(f"NOW CHANNEL: {ch}")
                if ch in merged_final:
                    userList = merged_final[ch]
                    logging.info(f"Sending to channel {ch}: {len(userList)} users")
                    accessToken = self._get_page_access_token(str(ch))
                    if not accessToken:
                        logging.error(f"No page access token found for channel {ch}")
                        continue
                    params = {"access_token": accessToken}
                    headers = {"Content-Type": "application/json"}
                    push_api_url = f"https://graph.facebook.com/v24.0/{ch}/messages"
                    
                    for user_id in userList:
                        payload = {
                            "recipient": {"id": user_id},
                            "messaging_type": "UPDATE",
                            "message": contentContext
                        }
                        response = requests.post(push_api_url, headers=headers, json=payload, params=params)
                        if response.status_code == 200:
                            sent += 1
                            if user_id in mappingUserSocial:
                                user_pseudo_ids_success.append(mappingUserSocial[user_id])
                        else:
                            error += 1
                            logging.error(f"FACEBOOK MESSAGE: {response.status_code} | {response.content}")
            user_pseudo_ids_fail = list(set(audience_list) - set(user_pseudo_ids_success))
            
            # Log execution with split context if available
            log_data = {
                "type": "destination",
                "channel_type": self.channel_type,
                "channel_id_count": len(self.channel_id),
                "total_user_pseudo_ids": len(audience_list),
                "user_pseudo_ids": audience_list,
                "user_pseudo_ids_success": user_pseudo_ids_success,
                "total_user_pseudo_ids_success": len(user_pseudo_ids_success),
                "user_pseudo_ids_fail": user_pseudo_ids_fail,
                "total_user_pseudo_ids_fail": len(user_pseudo_ids_fail),
            }
            
            self.log_execution(context, **log_data)
            
            logging.info(f"Successfully sent messages to {sent} users (error: {error})")
        return True

class ABNode(Node):
    def __init__(self, node_id, name, splits, seedKey):
        super().__init__(node_id, name)
        if len(splits) > 5:
            raise ValueError("ABNode supports a maximum of 5 paths.")
        self.splits = splits
        self.seedKey = seedKey
        self.total_percentage = sum(s["percentage"] for s in splits)

    def execute(self, context):
        audience_list = self.get_from_context(context, 'audience_list', [])
        if not audience_list:
            logging.info("No audience list found for ABNode splitting")
            return False
            
        cumulative = 0
        split_counts = []
        total = len(audience_list)
        
        logging.info(f"\n=== ABNode Execution ===")
        logging.info(f"Total audience: {total} users")
        logging.info(f"Audience list: {audience_list}")
        logging.info(f"Splits configuration: {self.splits}")
        
        for i, s in enumerate(self.splits):
            if i == len(self.splits) - 1:
                count = total - cumulative
            else:
                count = round(total * s["percentage"] / self.total_percentage)
                cumulative += count
            split_counts.append((s["target"], count))

        logging.info(f"\nABNode splitting {total} users into {len(self.splits)} groups:")
        for target, count in split_counts:
            logging.info(f"  {target}: {count} users ({count/total*100:.1f}%)")

        # Create split result dictionary
        split_result = {}
        index = 0
        for target, count in split_counts:
            split_result[target] = audience_list[index: index + count]
            logging.info(f"  {target} gets users: {split_result[target]}")
            index += count

        # Store the split result in context for downstream nodes
        self.add_to_context(context, 'audience_splits', split_result)
        self.add_to_context(context, 'current_split_target', None)  # Will be set by flow executor
        self.add_to_context(context, 'audience_size', total)

        self.log_execution(context, 
                          type="split", 
                          total_audience=total, 
                          splits=split_counts, 
                          split_result={k: len(v) for k, v in split_result.items()})

        logging.info(f"=== ABNode Execution Complete ===\n")
        return True
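
# Note: ABNode assigns users by positional slicing of audience_list; seedKey is
# stored but not currently used for deterministic hash-based assignment.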

class WaitNode(Node):
    def __init__(self, node_id, name, wait):
        super().__init__(node_id, name)
        self.wait = wait
        logging.info(f"WaitNode configured with: {self.wait}")
    
    def parse_duration(self, duration_str:str):
        """
        Parse duration strings like '2h', '10m', '1d' into timedelta.
        Supports:
        - 's' for seconds
        - 'm' for minutes
        - 'h' for hours
        - 'd' for days
        """
        pattern = r"(?P<value>\d+)(?P<unit>[smhd])"
        match = re.fullmatch(pattern, duration_str.strip())
        
        if not match:
            raise ValueError(f"Invalid duration format: {duration_str}")
        
        value = int(match.group("value"))
        unit = match.group("unit")
        
        if unit == "s":
            return timedelta(seconds=value)
        elif unit == "m":
            return timedelta(minutes=value)
        elif unit == "h":
            return timedelta(hours=value)
        elif unit == "d":
            return timedelta(days=value)
        
        raise ValueError(f"Unsupported time unit: {unit}")
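
    # e.g. parse_duration("90m") -> timedelta(minutes=90),
    #      parse_duration("2d")  -> timedelta(days=2)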


    def execute(self, context):
        now = datetime.now()

        if self.wait["type"] == "duration":
            delay = self.parse_duration(self.wait["value"])
            resume_time = now + delay
        else:
            resume_time = datetime.fromisoformat(self.wait["value"])

        # Save state in DB or scheduler system
        self.add_to_context(context, "_pause_until", resume_time.isoformat())
        self.add_to_context(context, "_paused", True)

        self.log_execution(context, 
                          type="wait", 
                          wait_type=self.wait["type"], 
                          resume_time=resume_time.isoformat())

        # Your executor must stop here and persist this context
        return False

class FlowExecutor:
    def __init__(self, property_id:str, flow_json: Dict, now:datetime = None):
        self.now = now
        self.nodes = {}
        self.connections = flow_json['json']["connections"]
        self.context = {
            "execution_log": [],
            "property_id": property_id,
            "flow_id": flow_json.get('id', 'unknown')
        }
        self.property_id = property_id
        self.last_triggered = flow_json['last_triggered']
        self.build_nodes(flow_json['json']["nodes"])

    def build_nodes(self, nodes_data: List[Dict]):
        for n in nodes_data:
            if n["type"] == "trigger":
                node = TriggerNode(
                    n["id"], n["name"], n['trigger'], n.get("cron"), n.get("interval"),
                    self.last_triggered, self.now, n['datebegin'], n['dateend'],
                    n.get('alwaywhen', 'off')
                )
            elif n["type"] == "audience":
                node = AudienceNode(self.property_id, n["id"], n["name"], n["audience_id"])
            elif n["type"] == "destination":
                node = DestinationNode(self.property_id, n["id"], n["name"], n["channel_id"], n['channel_type'], n["content"])
            elif n['type'] == 'splitNode':
                node = ABNode(n["id"], n["name"], n['splits'], n['seedKey'])
            elif n['type'] == 'waiting':
                node = WaitNode(n["id"], n["name"], n['waiting'])
            else:
                raise ValueError(f"Unknown node type: {n['type']}")
            self.nodes[n["id"]] = node
            logging.info(f"Built node: {node.__class__.__name__} ({node.name})")

    def create_split_context(self, base_context: Dict, split_target: str, audience_list: List) -> Dict:
        """
        Create a new context for a specific split path with isolated data.
        This ensures each split path has its own independent context.
        """
        import copy
        
        # Create a deep copy of the base context to avoid sharing references
        split_context = copy.deepcopy(base_context)
        
        # Set split-specific data
        split_context['current_split_target'] = split_target
        split_context['audience_list'] = audience_list
        split_context['audience_size'] = len(audience_list)
        
        # Add split path information to execution log
        split_context['execution_log'].append({
            "node_id": f"split_path_{split_target}",
            "name": f"Split Path: {split_target}",
            "type": "split_path_start",
            "audience_list": audience_list,
            "split_target": split_target,
            "audience_size": len(audience_list),
            "timestamp": datetime.now().isoformat()
        })
        
        logging.info(f"Created isolated context for split path '{split_target}' with {len(audience_list)} users")
        return split_context

    def merge_split_contexts(self, base_context: Dict, split_contexts: Dict[str, Dict]) -> Dict:
        """
        Merge all split contexts back into the base context after split execution is complete.
        This preserves the execution logs and any shared data from all split paths.
        """
        merged_context = base_context.copy()
        
        # Merge execution logs from all split paths
        all_logs = merged_context.get('execution_log', [])
        for split_target, split_context in split_contexts.items():
            split_logs = split_context.get('execution_log', [])
            all_logs.extend(split_logs)
        
        merged_context['execution_log'] = all_logs
        
        # Add summary of split execution
        merged_context['execution_log'].append({
            "node_id": "split_execution_summary",
            "name": "Split Execution Summary",
            "type": "split_summary",
            "total_split_paths": len(split_contexts),
            "split_paths": list(split_contexts.keys()),
            "timestamp": datetime.now().isoformat()
        })
        
        logging.info(f"Merged {len(split_contexts)} split contexts back into base context")
        return merged_context

    def execute_node_chain(self, node_id, context=None, split_target=None):
        """
        Execute a chain of nodes starting from the given node_id.
        
        Args:
            node_id: The ID of the node to start execution from
            context: The context to use for this execution path (if None, uses self.context)
            split_target: The split target if this is part of a split execution
        """
        if context is None:
            context = self.context
            
        node = self.nodes[node_id]
        
        # Execute the node with the provided context
        success = node.execute(context)
        
        # For WaitNodes in split contexts, we don't want to stop execution
        # as it would prevent other split paths from executing
        if not success and not (isinstance(node, WaitNode) and split_target):
            return context

        # Get next nodes
        next_nodes = self.connections.get(node_id, {}).get("main", [[]])[0]
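        # Connection schema: connections[node_id]["main"] is a list of output
        # branches; branch 0 holds the edges [{"node": "<next_node_id>"}, ...].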
        
        # If this was an ABNode, we need to handle multiple split paths
        if isinstance(node, ABNode) and 'audience_splits' in context:
            split_targets = list(context['audience_splits'].keys())
            logging.info(f"\n=== ABNode Split Execution ===")
            logging.info(f"ABNode completed, executing {len(split_targets)} split paths: {split_targets}")
            logging.info(f"Available connections: {len(next_nodes)}")
            
            # Create isolated contexts for each split path
            split_contexts = {}
            
            for i, conn in enumerate(next_nodes):
                if i < len(split_targets):
                    split_target = split_targets[i]
                    audience_list = context['audience_splits'].get(split_target, [])
                    
                    # Create isolated context for this split path
                    split_context = self.create_split_context(context, split_target, audience_list)
                    
                    logging.info(f"Starting split path '{split_target}' -> {conn['node']}")
                    # Execute the split path with its isolated context
                    final_split_context = self.execute_node_chain(conn["node"], split_context, split_target)
                    split_contexts[split_target] = final_split_context
                else:
                    logging.info(f"Warning: More connections ({len(next_nodes)}) than split targets ({len(split_targets)})")
            
            # Merge all split contexts back into the base context
            merged_context = self.merge_split_contexts(context, split_contexts)
            logging.info(f"=== ABNode Split Execution Complete ===\n")
            return merged_context
        else:
            # Normal single path execution
            for conn in next_nodes:
                context = self.execute_node_chain(conn["node"], context, split_target)
            return context

    def execute(self, is_manual=False):
        if is_manual:
            # Force every trigger node to fire regardless of its schedule.
            for node in self.nodes.values():
                if isinstance(node, TriggerNode):
                    node.trigger = 'manual'
        final_contexts = []
        for node in self.nodes.values():
            if isinstance(node, TriggerNode):
                final_context = self.execute_node_chain(node.node_id)
                final_contexts.append(final_context)
        
        # If multiple triggers were executed, merge their contexts
        if len(final_contexts) > 1:
            merged_context = self.context.copy()
            all_logs = []
            for ctx in final_contexts:
                all_logs.extend(ctx.get('execution_log', []))
            merged_context['execution_log'] = all_logs
            return merged_context["execution_log"]
        elif len(final_contexts) == 1:
            return final_contexts[0]["execution_log"]
        else:
            return self.context["execution_log"]
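
# Typical usage (sketch): FlowExecutor expects a flow record whose 'json' holds
# "nodes" and "connections" plus a 'last_triggered' timestamp, matching the keys
# read above. The Firebase path below is hypothetical.
#
#   flow_record = fb.db.reference().child(f"account/{property_id}/flow/{flow_id}").get()
#   executor = FlowExecutor(property_id, flow_record, now=datetime.now(timezone))
#   execution_log = executor.execute(is_manual=True)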