open Cairo
open Diagram
open UtilsLib
open AcgData.Signature
open AcgData.Acg_lexicon
open Logic.Lambda.Lambda
open Show_exts

module Lambda_show (T : Show_text_sig) = struct
  open T

  (** Fixed-point combinator for open-recursive functions: [fix f] is the
      function [g] satisfying [g x = f g x]. It is used to close
      [term_to_diagram_open], possibly after wrapping it with extra
      behaviour. *)
  let rec fix (f : ('a -> 'b) -> 'a -> 'b) : 'a -> 'b =
   fun arg -> f (fix f) arg

  (** [parenthesize_d (d, atomic)] returns [d] unchanged when [atomic] is
      [true], and [d] surrounded by parentheses otherwise. *)
  let parenthesize_d ((d, b) : diagram * bool) : diagram =
    if b then d else hcat [ n "("; d; n ")" ]

  (** Open-recursive renderer of a lambda term.

      [recur_fn] is called on every sub-term (close the recursion with
      {!fix}). [(l_env, env)] name the free linear and non-linear variables
      respectively, and [id_to_sym] maps constant ids to their symbol (the
      second component is the displayed name).

      Returns the diagram (horizontally centred with [centerX]) paired with
      a flag that is [true] only for variables and constants, i.e. diagrams
      that never need parentheses when used as an argument or operand.

      @raise Failure ["Not yet implemented"] for term constructors not
      handled below. *)
  let term_to_diagram_open
      (recur_fn : term -> env * env -> consts -> diagram * bool)
      (t : term) ((l_env, env) : env * env)
      (id_to_sym : consts) : diagram * bool =
    let here = (l_env, env) in
    (* Every recursive call shares the same symbol table. *)
    let sub u envs = recur_fn u envs id_to_sym in
    (* A sub-term used as argument / operand: parenthesized unless atomic. *)
    let operand u envs = parenthesize_d (sub u envs) in
    (* The bound names are reversed before display; presumably the
       [unfold_*] helpers accumulate them innermost-first (they start from
       [[fresh]]). *)
    let bound vars =
      n (Utils.string_of_list " " (fun x -> x) (List.rev vars))
    in
    (* "<prefix> x1 ... xk. <body>" -- never atomic. *)
    let binding prefix vars body envs =
      (hcat (prefix @ [ bound vars; n ". "; fst (sub body envs) ]), false)
    in
    let diagram, atomic =
      match t with
      | Var i -> (n (VNEnv.get i env), true)
      | LVar i -> (n (VNEnv.get i l_env), true)
      | Const id | DConst id -> (n (snd (id_to_sym id)), true)
      | Abs (x, body) ->
          let fresh = generate_var_name x here in
          let vars, (_, env'), body' =
            unfold_abs [ fresh ] (l_env, VNEnv.add fresh env) body
          in
          binding [ n "λ" ] vars body' (l_env, env')
      | LAbs (x, body) ->
          let fresh = generate_var_name x here in
          let vars, (l_env', _), body' =
            unfold_labs [ fresh ] (VNEnv.add fresh l_env, env) body
          in
          binding [ n "λᵒ" ] vars body' (l_env', env)
      | App (((Const id | DConst id) as binder), Abs (x, body))
        when is_binder id id_to_sym ->
          (* Binder applied to a non-linear abstraction, e.g. "∀ x. P". *)
          let fresh = generate_var_name x here in
          let vars, envs', body' =
            unfold_binder id id_to_sym [ fresh ]
              (l_env, VNEnv.add fresh env)
              body
          in
          binding [ operand binder envs'; n " " ] vars body' envs'
      | App (((Const id | DConst id) as binder), LAbs (x, body))
        when is_binder id id_to_sym ->
          (* Binder applied to a linear abstraction. *)
          let fresh = generate_var_name x here in
          let vars, envs', body' =
            unfold_binder id id_to_sym [ fresh ]
              (VNEnv.add fresh l_env, env)
              body
          in
          binding [ operand binder envs'; n " " ] vars body' envs'
      | App (App (((Const id | DConst id) as op), lhs), rhs)
        when is_infix id id_to_sym ->
          ( hcat
              [
                operand lhs here;
                n " ";
                operand op here;
                n " ";
                operand rhs here;
              ],
            false )
      | App (fn, arg) ->
          (* Flatten "f a1 ... ak" and separate everything by spaces. *)
          let args, head = unfold_app [ arg ] fn in
          let rendered_args = List.map (fun a -> operand a here) args in
          ( hcat
              (operand head here :: n " "
              :: Utils.intersperse (n " ") rendered_args),
            false )
      | _ -> failwith "Not yet implemented"
    in
    (centerX diagram, atomic)

  (** Renders a closed term (both variable environments empty) with the
      plain, non-embellished renderer and drops the atomicity flag. *)
  let term_to_diagram (t : term) (id_to_sym : consts) : diagram =
    let render = fix term_to_diagram_open in
    let d, _ = render t VNEnv.(empty, empty) id_to_sym in
    d
end

module Make
    (T : Show_text_sig)
    (C : Show_colors_sig)
    (Emb : Show_embellish_sig) =
struct
  type signature = Data_Signature.t
  type lexicon = AcgData.Acg_lexicon.Data_Lexicon.t
  type term = Data_Signature.term
  type 'a tree = 'a Tree.t

  open T
  module L = Lambda_show (T)
  open L

  (** [replace_with_dict dict s] replaces in [s] every occurrence of [ugly]
      by [pretty], for each pair [(ugly, pretty)] of [dict]. Both strings
      are taken literally. Because this is a right fold, the pairs are
      applied from the last one to the first one; that only matters when
      the patterns overlap. *)
  let replace_with_dict : (string * string) list -> string -> string =
    List.fold_right (fun (ugly, pretty) ->
        (* A constant substitution function inserts [pretty] verbatim,
           whereas [Str.global_replace] would interpret backslash sequences
           (e.g. [\1]) in it as group references. *)
        Str.global_substitute (Str.regexp_string ugly) (fun _ -> pretty))

  (** Renders a type of [sg] as a single text diagram, displaying [->] as
      [⊸] and [=>] as [→]. *)
  let[@warning "-unused-value-declaration"] type_to_diagram (sg : signature) (ty : stype) : diagram =
    Format.asprintf "%a" (Data_Signature.pp_type sg) ty
    |> replace_with_dict [ ("->", "⊸"); ("=>", "→") ]
    |> n

  (** Abstract signature of a lexicon (first component of [get_sig]). *)
  let abstract_sig (lex : lexicon) : signature = fst @@ Data_Lexicon.get_sig lex

  (** Object signature of a lexicon (second component of [get_sig]). *)
  let object_sig (lex : lexicon) : signature = snd @@ Data_Lexicon.get_sig lex

  (** Name of a signature (first component of [Data_Signature.name]). *)
  let sig_name (sg : signature) : string = fst @@ Data_Signature.name sg

  (** Interprets an abstract term through [lex], then normalizes the result,
      unfolding the definitions of the lexicon's object signature. *)
  let interpret_term (t : term) (lex : lexicon) : term =
    Data_Lexicon.interpret_term t lex
    |> normalize ~id_to_term:(fun i ->
           Data_Signature.unfold_term_definition i @@ object_sig lex)

  (** Builds the tree of sub-terms of [t]: abstractions have their body as
      only child, infix applications [op t1 t2] have children
      [t1; op; t2], other applications have [fn; arg], and variables and
      constants are leaves.

      @raise Failure ["Not yet implemented"] on other term constructors. *)
  let rec term_to_graph (sg : signature) (t : term) : term tree =
    let children =
      match t with
      | Var _ | LVar _ | Const _ | DConst _ -> []
      | Abs (_, body) | LAbs (_, body) -> [ body ]
      | App (App (((Const id | DConst id) as op), t1), t2)
        when is_infix id (Data_Signature.id_to_string_unsafe sg) ->
          [ t1; op; t2 ]
      | App (fn, arg) -> [ fn; arg ]
      | _ -> failwith "Not yet implemented"
    in
    Tree.T (t, List.map (term_to_graph sg) children)

  (** Renders every node of a term tree with [render_term], threading the
      variable environments so that variables bound by an abstraction above
      a node get names when that node is rendered.

      A linear abstraction binds its variable in the linear environment,
      unless [non_linear] is set (used for non-linear lexicons), in which
      case every abstraction binds in the non-linear environment. *)
  let rec render_term_graph ?(non_linear = false)
      ((l_env, env) : env * env)
      (render_term : term -> env * env -> diagram)
      (Tree.T (term, children) : term tree) : diagram tree =
    let render_children_in (l_env, env) =
      List.map
        (render_term_graph ~non_linear (l_env, env) render_term)
        children
    in

    let children_d =
      match term with
      | LAbs (x, _) when not non_linear ->
          let x' = generate_var_name x (l_env, env) in
          render_children_in (VNEnv.add x' l_env, env)
      | Abs (x, _) | LAbs (x, _) ->
          let x' = generate_var_name x (l_env, env) in
          render_children_in (l_env, VNEnv.add x' env)
      | _ -> render_children_in (l_env, env)
    in

    Tree.T (render_term term (l_env, env), children_d)

  (** Renders [t] (a term of [sg]) in the given environments, using the
      renderer of {!Lambda_show} wrapped with the embellishments selected
      for [sg]'s name and [config]. *)
  let term_to_diagram_in (config : Rendering_config.config) (sg : signature)
      (t : term) ((l_env, env) : env * env) :
      diagram =
    let ttd =
      term_to_diagram_open
      |> Emb.embellishments_functions (sig_name sg) config
      |> fix
    in
    let consts = Data_Signature.id_to_string_unsafe sg in
    fst @@ ttd t (l_env, env) consts

  (** Zips a non-empty list of trees into one tree whose nodes list the
      corresponding nodes of every input tree, in input order.
      NOTE(review): relies on [Utils.fold_left1] / [Tree.map2]; presumably
      fails on an empty list or on trees of different shapes. *)
  let merge_trees : 'a tree list -> 'a list tree =
    List.map (Tree.map (fun x -> [ x ])) >> Utils.fold_left1 (Tree.map2 ( @ ))

  (** Pads every line (2.0 absolute horizontal padding, 0.05 relative
      vertical padding) and colors the i-th line with the i-th entry of
      [C.lines], cycled to the number of lines. *)
  let decorate_lines (lines : diagram list) : diagram list =
    let colors = Utils.cycle (List.length lines) C.lines in
    lines
    |> List.map (pad_abs ~horizontal:2.0)
    |> List.map (pad_rel ~vertical:0.05)
    |> List.map2 color colors

  (** Recursively pads, symmetrically and vertically, the lines of sibling
      nodes so that the i-th line of every child of a node has the same
      height (the maximum over the siblings). Assumes siblings have the same
      number of lines ([List.map2] raises otherwise). *)
  let rec align_sister_lines (tree : diagram list tree) : diagram list tree =
    match tree with
    | Tree.T (_lines, []) -> tree
    | Tree.T (lines, children) ->
        let children = List.map align_sister_lines children in
        let heights =
          List.map
            (fun (Tree.T (c_lines, _)) ->
              List.map (fun c_line -> (extents c_line).h) c_lines)
            children
        in
        let max_heights = Utils.fold_left1 (List.map2 max) heights in
        let children =
          List.map
            (fun (Tree.T (c_lines, c_children)) ->
              Tree.T
                ( List.map2
                    (fun c_line new_height ->
                      let height_diff = new_height -. (extents c_line).h in
                      pad_abs ~vertical:(height_diff /. 2.) c_line)
                    c_lines max_heights,
                  c_children ))
            children
        in
        Tree.T (lines, children)

  (** Draws the derivation tree of [abs_term]. Each node stacks the abstract
      sub-term with its interpretation through every lexicon of [lexs], in
      order. The abstract signature is the one of the first lexicon. When
      expanding [abs_term]'s definitions (then normalizing) changes it, the
      expanded term is the one drawn.

      @raise Failure if [lexs] is empty (from [List.hd]). *)
  let realize_diagram (abs_term : term) (lexs : lexicon list)
      (config : Rendering_config.config) : diagram =
    let abs_sig = abstract_sig @@ List.hd lexs in

    let expanded_abs_term =
      normalize (Data_Signature.expand_term abs_term abs_sig)
    in
    (* Structural comparison: [normalize] may rebuild an unchanged term, so
       physical inequality ([!=]) could report a difference for identical
       terms. [Stdlib] is explicit because several modules are opened. *)
    let abs_terms_differ = Stdlib.( <> ) abs_term expanded_abs_term in

    (* A single graph drives every row so that [merge_trees] sees trees of
       identical shape. *)
    let abs_term_graph =
      term_to_graph abs_sig
        (if abs_terms_differ then expanded_abs_term else abs_term)
    in

    let render_abs_term = term_to_diagram_in config abs_sig in

    let render_obj_term lex abs_term =
      let obj_sig = object_sig lex in
      let obj_term = interpret_term abs_term lex in
      term_to_diagram_in config obj_sig obj_term
    in

    let abs_term_tree =
      render_term_graph VNEnv.(empty, empty) render_abs_term abs_term_graph
    in
    let obj_term_trees =
      List.map
        (fun lex ->
          render_term_graph
            ~non_linear:(not (Data_Lexicon.is_linear lex))
            VNEnv.(empty, empty) (render_obj_term lex) abs_term_graph)
        lexs
    in

    (abs_term_tree :: obj_term_trees)
    |> merge_trees
    |> Tree.map decorate_lines
    |> align_sister_lines
    |> Tree.map (List.map centerX >> vcat)
    |> Tree.map (bg_color (C.node_background config))
    |> Tree.to_diagram
    |> setup (fun cr -> set_line_width cr 1.5)
    |> bg_color (C.background config)
    |> color C.tree
end
