import proto
import json
import altair as alt
from google.protobuf.json_format import MessageToDict

def handle_text_response(resp):
    """Concatenate the text fragments of a text response.

    Args:
        resp: An object with a ``parts`` attribute holding an iterable of
            strings.

    Returns:
        All parts joined together with no separator.
    """
    # Plain attribute access: getattr() with a constant name adds nothing.
    return "".join(resp.parts)

# Applied to every string dict key: drop "$" and turn "." into "_".
_KEY_TRANSLATION = str.maketrans({"$": None, ".": "_"})


def sanitize_keys(obj):
    """Recursively sanitize the keys of every dict nested in *obj*.

    Each string key has all ``$`` characters removed and every ``.``
    replaced by ``_``. Lists are traversed element-wise. All other values,
    and non-string dict keys, are returned unchanged.

    Note: distinct keys can collide after sanitizing (e.g. ``"a.b"`` and
    ``"a_b"``). The value that appears later in iteration order wins.

    Args:
        obj: A dict, list, or scalar value (e.g. the result of ``json.loads``).

    Returns:
        A newly built structure with sanitized dict keys. Dicts and lists
        in *obj* are not mutated.
    """
    if isinstance(obj, dict):
        sanitized = {}
        for key, value in obj.items():
            if isinstance(key, str):
                key = key.translate(_KEY_TRANSLATION)
            sanitized[key] = sanitize_keys(value)
        return sanitized
    if isinstance(obj, list):
        return [sanitize_keys(item) for item in obj]
    return obj

def handle_chart_response(resp):
    """Return the query or a sanitized chart spec from a chart response.

    Args:
        resp: A response message that supports ``in`` field-presence checks
            (presumably a proto-plus message).
            - If ``query`` is set, that value is returned as-is.
            - Otherwise, if ``result`` is set, ``result.vega_config`` is:
              converted to plain Python, loaded into an ``altair.Chart``,
              round-tripped through JSON, and passed to ``sanitize_keys``.

    Returns:
        ``resp.query``, a sanitized chart-spec dict, or ``None`` when
        neither field is present.
    """

    def _convert(value):
        # Recursively turn proto-plus containers into plain dicts/lists so
        # that altair can consume them.
        if isinstance(value, proto.marshal.collections.maps.MapComposite):
            return {key: _convert(item) for key, item in value.items()}
        # Repeated is the base class of RepeatedComposite, so this matches
        # both message and scalar repeated containers.
        if isinstance(value, proto.marshal.collections.Repeated):
            return [_convert(item) for item in value]
        # None must pass through: MessageToDict(None) would raise.
        if value is None or isinstance(value, (int, float, str, bool)):
            return value
        return MessageToDict(value)

    if "query" in resp:
        return resp.query
    if "result" in resp:
        # Hack from https://github.com/streamlit/streamlit/issues/6269
        # TODO: Make use of st.altair_chart once either of the issues below
        # is resolved:
        # https://github.com/streamlit/streamlit/issues/6269
        # https://github.com/streamlit/streamlit/issues/1196
        # Then we can make use of the altair example in our python sdk
        # documentation.
        chart = alt.Chart.from_dict(_convert(resp.result.vega_config))
        json_obj = json.loads(chart.to_json())
        return sanitize_keys(json_obj)
    return None


def handle_data_response(resp):
    """Pivot a data response's rows into a column-oriented dict.

    Args:
        resp: A response message that supports ``in`` field-presence checks
            (presumably a proto-plus message). When ``result`` is set,
            ``result.schema.fields`` names the columns and ``result.data``
            holds rows indexable by field name.

    Returns:
        ``None`` if the response carries ``query`` or has no ``result``.
        Otherwise, a dict (keys passed through ``sanitize_keys``) mapping
        each schema field name to the list of that field's values in row
        order. Every schema field gets a key, even when there are no rows.
    """
    if "query" in resp:
        return None
    if "result" not in resp:
        return None
    fields = [field.name for field in resp.result.schema.fields]
    # Pre-create every column so an empty result still exposes the schema.
    columns = {name: [] for name in fields}
    for row in resp.result.data:
        for name in fields:
            columns[name].append(row[name])
    return sanitize_keys(columns)