import re
from typing import Any, Dict, List, Union, Iterable, Optional, Tuple

class Keywords:
    """Stateless helpers for summarizing, flattening, and matching Firebase
    keyword trees, plus labeling free text against the flattened schema.

    Every method is a ``@staticmethod`` (no instance state); internal calls go
    through the class name so the helpers also work when accessed on an
    instance.
    """

    @staticmethod
    def _is_meta_key(k: str) -> bool:
        """Return True if *k* is bookkeeping metadata rather than a real category child."""
        META_KEYS = {
            "name", "other", "level", "createdate", "lastupdate",
            "include_kw", "exclude_kw",
            "category_1", "category_2", "category_3", "category_4", "category_5"
        }
        return k in META_KEYS or k.lower().startswith("category_")

    @staticmethod
    def _cat_index(level_value: str) -> int:
        """Extract N from a ``category_N`` label; return -1 for malformed labels."""
        try:
            return int(level_value.split("_", 1)[1])
        except Exception:
            return -1

    @staticmethod
    def _dfs_deepest(node: dict, deepest: dict) -> None:
        """Depth-first scan recording the deepest ``level`` label seen.

        Mutates *deepest* in place: ``{"idx": int, "label": str}``. Meta keys
        are skipped so only real category children are traversed.
        """
        if not isinstance(node, dict):
            return
        lvl = node.get("level")
        if lvl:
            idx = Keywords._cat_index(lvl)
            if idx > deepest["idx"]:
                deepest["idx"] = idx
                deepest["label"] = lvl
        for k, v in node.items():
            if Keywords._is_meta_key(k):
                continue
            Keywords._dfs_deepest(v, deepest)

    @staticmethod
    def _structure_view(node: dict):
        """
        Returns either:
        - list[str]  if all real children are leaves (no further dict children)
        - dict[str, list|dict] if any child has its own children (recursively formatted)
        Skips meta keys and omits empty children. Returns None when *node* has
        no real (dict, non-meta) children at all.
        """
        # collect real children (dicts, non-meta)
        real_children = [(k, v) for k, v in node.items()
                        if isinstance(v, dict) and not Keywords._is_meta_key(k)]
        if not real_children:
            return None  # no structure here

        # Determine for each child whether it has grandchildren
        leaf_children = []
        branch_children = {}  # k -> formatted sub-structure (list/dict, non-empty)
        for k, v in real_children:
            # grandchildren = dict children of v (non-meta)
            grandchildren = [(gk, gv) for gk, gv in v.items()
                            if isinstance(gv, dict) and not Keywords._is_meta_key(gk)]
            if not grandchildren:
                # This child is a leaf category
                leaf_children.append(k)
            else:
                sub = Keywords._structure_view(v)  # recurse
                if sub:  # keep only if non-empty
                    branch_children[k] = sub

        # If all children are leaves -> return list (only if leaves exist)
        if branch_children == {}:
            return leaf_children if leaf_children else None

        # Else return dict of branch children; leaf siblings of branches are
        # intentionally omitted to keep the view clean.
        return branch_children

    @staticmethod
    def summarize_minimal(templates: dict) -> List[Dict[str, Any]]:
        """
        Input  : dict from fb.db.reference('template/keyword').get()
        Output : list of {createdate, lastupdate, deepest_level, hierarchy}
        - hierarchy collapses to a list when children are leaves, otherwise a
          dict mapping; it is wrapped under the outermost key.
        """
        results = []
        for top_key, top_obj in (templates or {}).items():
            if not isinstance(top_obj, dict):
                continue

            createdate = top_obj.get("createdate")
            lastupdate = top_obj.get("lastupdate")

            # Default to category_1 so templates without explicit levels still report one.
            deepest = {"idx": 1, "label": "category_1"}
            Keywords._dfs_deepest(top_obj, deepest)

            view = Keywords._structure_view(top_obj)
            # wrap under the outermost key, only if view is present
            structure = {top_key: view} if view else {top_key: []}

            results.append({
                "createdate": createdate,
                "lastupdate": lastupdate,
                "deepest_level": deepest["label"],
                "hierarchy": structure
            })
        return results

    # Keys never treated as category path segments while flattening.
    DEFAULT_RESERVED = {
        "level","name","other","apply_to","createdate","lastupdate",
        "include_kw","exclude_kw","priority"
    }
    # Keys skipped while scanning for per-subpath priorities.
    EXCLUDED_KEYS = {
        "apply_to","createdate","lastupdate","name","level","other",
        "include_kw","exclude_kw"
    }

    @staticmethod
    def _sanitize_cat_part(s: str) -> str:
        """Match your category naming (keep case, Thai; turn spaces/punct into underscores)."""
        s = (s or "").strip().replace("/", " ")
        s = re.sub(r"\s+", " ", s)
        s = re.sub(r"[^\w\u0E00-\u0E7F]+", "_", s)
        s = re.sub(r"_+", "_", s).strip("_")
        return s

    @staticmethod
    def _to_str_list(x: Any) -> List[str]:
        """Coerce a Firebase list (list, or dict with numeric string keys) to a
        de-duplicated list of stripped, non-empty strings; order preserved."""
        if x is None:
            return []
        if isinstance(x, list):
            arr = [str(v).strip() for v in x if isinstance(v, (str,int,float)) and str(v).strip()]
        elif isinstance(x, dict) and x and all(isinstance(k, str) and k.isdigit() for k in x.keys()):
            arr = [str(x[k]).strip() for k in sorted(x.keys(), key=lambda z: int(z))
                if isinstance(x[k], (str,int,float)) and str(x[k]).strip()]
        else:
            return []
        seen, out = set(), []
        for s in arr:
            if s not in seen:
                seen.add(s); out.append(s)
        return out

    @staticmethod
    def flatten_keyword_tree_to_list(
        root: Dict[str, Any],
        excluded_fields: Iterable[str] = ("apply_to","createdate","lastupdate","name")
    ) -> List[Dict[str, Dict[str, List[str]]]]:
        """
        Emit one category per bucket that has include_kw.
        Category key = path joined with underscores (e.g., product_Augment_nose_alarpastry).
        Skips any keys in excluded_fields + DEFAULT_RESERVED while walking.
        """
        reserved = Keywords.DEFAULT_RESERVED | set(excluded_fields)
        results: List[Dict[str, Dict[str, List[str]]]] = []

        def walk(node: Any, path: List[str]):
            if not isinstance(node, (dict, list)):
                return

            if isinstance(node, dict):
                if "include_kw" in node:
                    include_kw = Keywords._to_str_list(node.get("include_kw"))
                    # accept either key spelling for excludes
                    exc_kw = Keywords._to_str_list(node.get("exc_kw") or node.get("exclude_kw"))
                    if path:
                        cat_key = "_".join(Keywords._sanitize_cat_part(p) for p in path)
                        payload = {"include_kw": include_kw}
                        if exc_kw:
                            payload["exc_kw"] = exc_kw
                        results.append({cat_key: payload})

                for k, v in node.items():
                    if k in reserved:
                        continue
                    if isinstance(v, (dict, list)):
                        walk(v, path + [k])

            else:
                # lists may contain nested dicts; path does not grow here
                for v in node:
                    if isinstance(v, (dict, list)):
                        walk(v, path)

        for top_key, top_val in (root or {}).items():
            if isinstance(top_val, (dict, list)):
                walk(top_val, [Keywords._sanitize_cat_part(top_key)])

        return results

    @staticmethod
    def _to_int_or_none(x: Any) -> Optional[int]:
        """Best-effort int conversion; None for missing/unparseable values."""
        try:
            if x is None: return None
            return int(str(x).strip())
        except Exception:
            return None

    # -------------------------
    # Counting helpers
    # -------------------------

    # A pattern is either a single token or an AND-group of tokens (from "A&&B").
    Pattern = Union[str, List[str]]

    @staticmethod
    def _norm(s: str) -> str:
        """Collapse all whitespace runs to single spaces."""
        return " ".join((s or "").split())

    @staticmethod
    def _count_token(text: str, token: str) -> int:
        """Count case-insensitive, overlapping-free occurrences of *token* in *text*."""
        if not token:
            return 0
        return len(re.findall(re.escape(token), text, flags=re.IGNORECASE))

    @staticmethod
    def _parse_pattern(p: str) -> Pattern:
        # supports "A&&B" -> ["A","B"]
        if "&&" in p:
            parts = [t.strip() for t in p.split("&&") if t.strip()]
            return parts if parts else p
        return p

    @staticmethod
    def _count_pattern(text: str, patt: Pattern) -> int:
        """AND-groups score as the minimum of their member counts."""
        if isinstance(patt, list):  # AND-group: min of counts
            return min(Keywords._count_token(text, t) for t in patt) if patt else 0
        return Keywords._count_token(text, patt)

    @staticmethod
    def _dedupe(seq: Iterable[str]) -> List[str]:
        """Order-preserving de-duplication."""
        seen, out = set(), []
        for s in seq:
            if s not in seen:
                seen.add(s); out.append(s)
        return out

    @staticmethod
    def label_to_bq(message: str, schema_dict: Dict[str, Dict[str, List[str]]]) -> List[Dict[str, Any]]:
        """Match *message* against every category in *schema_dict* and return
        BQ-shaped labeling rows. Only categories with at least one include hit
        are emitted; exclude hits are reported alongside (empty list if none).
        """
        t = Keywords._norm(message or "")
        labeling: List[Dict[str, Any]] = []

        for cat, cfg in schema_dict.items():
            include_kw = Keywords._dedupe(cfg.get("include_kw", []))
            exclude_kw = Keywords._dedupe(cfg.get("exc_kw", cfg.get("exclude_kw", [])))

            # count includes
            inc_hits = []
            for tok in include_kw:
                patt = Keywords._parse_pattern(tok)
                c = Keywords._count_pattern(t, patt)
                if c > 0:
                    inc_hits.append({"keyword": tok, "value": int(c)})

            # count excludes (optional)
            exc_hits = []
            for tok in exclude_kw:
                patt = Keywords._parse_pattern(tok)
                c = Keywords._count_pattern(t, patt)
                if c > 0:
                    exc_hits.append({"keyword": tok, "value": int(c)})

            # Only emit categories that matched at least one include
            if inc_hits:
                labeling.append({
                    "category": {
                        "include": inc_hits,
                        "exclude": exc_hits,        # [] if none hit
                        "category_name": cat        # <-- new field
                    }
                })

        return labeling

    @staticmethod
    def _normalize_fb_list(x: Any) -> List[str]:
        """Normalize list from Firebase (supports list or dict with numeric keys)."""
        if x is None:
            return []
        if isinstance(x, list):
            vals = [str(v).strip() for v in x if isinstance(v, (str, int))]
        elif isinstance(x, dict) and x and all(isinstance(k, str) and k.isdigit() for k in x.keys()):
            vals = [str(x[k]).strip() for k in sorted(x.keys(), key=lambda z: int(z))]
        else:
            vals = [str(x).strip()] if isinstance(x, (str, int)) else []
        return [v for v in vals if v]

    @staticmethod
    def build_apply_to_map_if_contains_id(
        fb,
        property_id: str,
        required_id: str,
        channel_key: str = "facebook_comment",  # e.g. "facebook_message" later
    ) -> Dict[str, Dict[str, List[str]]]:
        """
        Returns only groups whose apply_to/<channel_key> contains required_id.
        Example: {'branch': {'facebook_comment': ['1043...']}, 'product': {...}}
        If missing or not found -> returns {} for that group (skips it).
        """
        root_path = f"account/{property_id}/keyword"
        kw_root = fb.db.reference(root_path).get() or {}
        out: Dict[str, Dict[str, List[str]]] = {}

        for group, node in (kw_root.items() if isinstance(kw_root, dict) else []):
            if not isinstance(node, dict):
                continue
            fb_chan_values = (node.get("apply_to") or {}).get(channel_key)
            vals = Keywords._normalize_fb_list(fb_chan_values)
            if not vals:
                continue
            # compare as strings so int/str ids match
            if str(required_id) in {str(v) for v in vals}:
                out[group] = {channel_key: vals}
        return out

    # ---- Optional: use the map to filter schema_dict by group/prefix ----
    @staticmethod
    def filter_schema_by_apply_to(
        schema_dict: Dict[str, Dict[str, List[str]]],
        apply_to_map: Dict[str, Dict[str, List[str]]],
    ) -> Dict[str, Dict[str, List[str]]]:
        """
        Keep only categories whose prefix (before first underscore) is in allowed groups.
        e.g., 'branch_Asok' -> group 'branch'
        """
        if not apply_to_map:
            return schema_dict  # do nothing
        allowed_groups = set(apply_to_map.keys())
        filtered = {}
        for cat, cfg in schema_dict.items():
            group = cat.split("_", 1)[0] if "_" in cat else cat
            if group in allowed_groups:
                filtered[cat] = cfg
        return filtered

    @staticmethod
    def load_group_priorities_deep(
        fb,
        property_id: str,
        groups: Optional[List[str]] = None,
        prefer_deepest: bool = True,          # if True, deeper path wins on conflict
        ascending_priority: bool = True       # smaller number = higher priority
    ) -> Dict[str, Dict[str, int]]:
        """
        Returns e.g.:
        {
        "product": {"Augment_nose_alarpastry": 4, "Augment_nose_revision": 7, ...},
        "purpose": {"mid_intent": 10, "low_intent": 20}
        }
        Scans any depth under each group for a 'priority' field.
        """
        root_path = f"account/{property_id}/keyword"
        kw_root = fb.db.reference(root_path).get() or {}
        out: Dict[str, Dict[str, int]] = {}
        meta: Dict[Tuple[str,str], Tuple[int,int]] = {}  # (group, sub) -> (depth, pr)

        if not isinstance(kw_root, dict):
            return out

        # Limit to given groups, else all top-level dict groups
        top_groups = groups or [g for g, v in kw_root.items() if isinstance(v, dict)]

        def walk(group: str, node: Any, parts_after_group: List[str], depth: int):
            if not isinstance(node, dict):
                return

            # If this node has a priority, record it for this subpath
            pr = Keywords._to_int_or_none(node.get("priority"))
            if pr is not None and parts_after_group:
                sub = "_".join(Keywords._sanitize_cat_part(p) for p in parts_after_group)
                key = (group, sub)
                if key not in meta:
                    meta[key] = (depth, pr)
                else:
                    old_depth, old_pr = meta[key]
                    take = False
                    if prefer_deepest and depth > old_depth:
                        take = True
                    elif prefer_deepest and depth == old_depth:
                        # same depth -> pick better priority
                        take = (pr < old_pr) if ascending_priority else (pr > old_pr)
                    elif not prefer_deepest:
                        # ignore depth, compare priority only
                        take = (pr < old_pr) if ascending_priority else (pr > old_pr)
                    if take:
                        meta[key] = (depth, pr)

            # Recurse into children (skip reserved keys)
            for k, v in node.items():
                if k in Keywords.EXCLUDED_KEYS:
                    continue
                if isinstance(v, dict):
                    walk(group, v, parts_after_group + [k], depth + 1)
                elif isinstance(v, list):
                    # Some trees put dicts in lists; walk dict items inside lists
                    for item in v:
                        if isinstance(item, dict):
                            walk(group, item, parts_after_group + [k], depth + 1)

        for g in top_groups:
            gnode = kw_root.get(g)
            if not isinstance(gnode, dict):
                continue
            # Start under the group (subpath is empty at root of group)
            walk(g, gnode, [], 0)

        # Build output map
        for (group, sub), (_depth, pr) in meta.items():
            out.setdefault(group, {})[sub] = pr

        return out

    @staticmethod
    def choose_single_group_labels(
        labeling_row: Any,
        priority_map: Dict[str, Dict[str, int]],
        ascending_priority: bool = True,            # smaller number = higher priority
        limit_to_groups: Optional[List[str]] = None,# e.g. ["purpose"]; None = all
        drop_excluded: bool = True,                 # skip categories with exclude hits
        drop_all_if_any_exclude: bool = False       # optional: clear everything if any exclude hit
    ) -> List[Dict[str, str]]:
        """
        labeling_row is your BQ-shaped list from label_to_bq, e.g.:
        [
            {"category": {
                "include":[{"keyword":"สาขาไหน","value":1}],
                "exclude":[...],                     # present only if matched
                "category_name":"purpose_mid_intent"
            }},
            ...
        ]

        Returns e.g. [{"key":"purpose","value":"mid_intent"}]
        """
        if not labeling_row or not isinstance(labeling_row, list):
            return []

        # Optional global drop if any exclude in the row
        if drop_all_if_any_exclude:
            for rec in labeling_row:
                cat = (rec or {}).get("category") or {}
                exc = cat.get("exclude") or []
                if isinstance(exc, list) and len(exc) > 0:
                    return []  # remove the whole single_grouping

        # 1) Collect candidates per top-level group (skip excluded if requested)
        per_group: Dict[str, List[Dict[str, Any]]] = {}
        for rec in labeling_row:
            cat = (rec or {}).get("category") or {}
            cat_name = cat.get("category_name")
            if not isinstance(cat_name, str) or "_" not in cat_name:
                continue

            # Skip this category if any exclude matched and drop_excluded is True
            if drop_excluded:
                exc = cat.get("exclude") or []
                if isinstance(exc, list) and len(exc) > 0:
                    continue

            group, sub = cat_name.split("_", 1)
            if limit_to_groups and group not in set(limit_to_groups):
                continue

            include_list = cat.get("include") or []
            score = sum(int(item.get("value", 0) or 0) for item in include_list if isinstance(item, dict))
            per_group.setdefault(group, []).append({"sub": sub, "score": int(score)})

        # 2) For each group, pick one using priority (then score, then alphabetical)
        chosen: List[Dict[str, str]] = []
        for group, items in per_group.items():
            if not items:
                continue

            gprio = priority_map.get(group, {})
            with_prio = [(it["sub"], gprio[it["sub"]], it["score"])
                        for it in items if it["sub"] in gprio]

            if with_prio:
                # sort by priority (best first), then score desc, then name
                if ascending_priority:
                    with_prio.sort(key=lambda x: (x[1], -x[2], x[0]))
                else:
                    with_prio.sort(key=lambda x: (-x[1], -x[2], x[0]))
                best_sub = with_prio[0][0]
            else:
                # no priorities known for this group: highest score, then name
                items.sort(key=lambda x: (-x["score"], x["sub"]))
                best_sub = items[0]["sub"]

            chosen.append({"key": group, "value": best_sub})

        return chosen
    
    
class BigQueryKeyword:
    """Builders for BigQuery SQL strings that pull raw chat/comment events.

    NOTE(review): all values are interpolated into the SQL with f-strings.
    That is only safe while property_id / page ids / dates come from trusted
    configuration — if any value can be user-controlled, switch to BigQuery
    query parameters instead of string interpolation.
    """

    @staticmethod
    def facebook_comment_query(property_id, fb_pageId, date_start, date_end):
        """Return SQL selecting Facebook comment_add events for one page.

        Excludes the page's own comments (ps_id != page id) and rows whose
        comment_text is NULL or empty. Bug fix: the NULL/empty filter must use
        AND — with OR, empty strings passed because `x IS NOT NULL` is TRUE.
        """
        facebook_comment_query = f"""
            SELECT
                eventId,
                eventTimeStamp,
                eventName,
                pageId,
                user_pseudo_id,
                source,
                MAX(IF(ep.key = "ps_id", ep.value, NULL)) AS ps_id,
                MAX(IF(ep.key = "facebook_name", ep.value, NULL)) as facebook_name,
                MAX(IF(ep.key = "comment_id", ep.value, NULL)) as comment_id,
                MAX(IF(ep.key = "parent_comment_id", ep.value, NULL)) as parent_comment_id,
                MAX(IF(ep.key = "comment_type", ep.value, NULL)) as comment_type,
                MAX(IF(ep.key = "post_id", ep.value, NULL)) as post_id,
                MAX(IF(ep.key = "comment_text", ep.value, NULL)) as comment_text
            FROM `customer-360-profile.client_{property_id}.event` ,
                UNNEST(eventProperty) AS ep
            WHERE eventName = 'comment_add' 
                AND source = 'facebook'
                AND pageId = '{fb_pageId}'
                AND eventTimeStamp BETWEEN '{date_start}' AND '{date_end}'
            GROUP BY 1,2,3,4,5,6
            HAVING ps_id != '{fb_pageId}'
                AND (comment_text IS NOT NULL AND comment_text != "")
        """
        return facebook_comment_query

    @staticmethod
    def facebook_message_query(property_id, facebook_channel_id, date_start, date_end):
        """Return SQL selecting Facebook user_message events for one channel.

        Rows whose message is NULL or empty are excluded (AND, not OR — see
        facebook_comment_query for the three-valued-logic rationale).
        """
        facebook_chat_query = f"""
            SELECT
                eventId,
                eventTimeStamp,
                eventName,
                pageId,
                user_pseudo_id,
                source,
                id as line_uid,
                MAX(IF(ep.key = "message", ep.value, NULL)) as message
            FROM `customer-360-profile.client_{property_id}.event` ,
                UNNEST(eventProperty) AS ep
            WHERE eventName = 'user_message' 
                AND source = 'facebook'
                AND pageId = '{facebook_channel_id}'
                AND eventTimeStamp BETWEEN '{date_start}' AND '{date_end}'
            GROUP BY 1,2,3,4,5,6,7
            HAVING (message IS NOT NULL AND message != "")
        """
        return facebook_chat_query

    @staticmethod
    def line_chat_query(property_id, line_id, date_start, date_end):
        """Return SQL selecting LINE user_message events for one channel.

        Same shape as facebook_message_query with source = 'line'; NULL/empty
        messages are excluded with AND.
        """
        line_chat_query = f"""
            SELECT
                eventId,
                eventTimeStamp,
                eventName,
                pageId,
                user_pseudo_id,
                source,
                id as line_uid,
                MAX(IF(ep.key = "message", ep.value, NULL)) as message
            FROM `customer-360-profile.client_{property_id}.event` ,
                UNNEST(eventProperty) AS ep
            WHERE eventName = 'user_message' 
                AND source = 'line'
                AND pageId = '{line_id}'
                AND eventTimeStamp BETWEEN '{date_start}' AND '{date_end}'
            GROUP BY 1,2,3,4,5,6,7
            HAVING (message IS NOT NULL AND message != "")
        """
        return line_chat_query