(**************************************************************************)
(*                                                                        *)
(*                 ACG development toolkit                                *)
(*                                                                        *)
(*                  Copyright 2008-2024 INRIA                             *)
(*                                                                        *)
(*  More information on "https://acg.loria.fr/"                     *)
(*  License: CeCILL, see the LICENSE file or "http://www.cecill.info"     *)
(*  Authors: see the AUTHORS file                                         *)
(*                                                                        *)
(*                                                                        *)
(*                                                                        *)
(*                                                                        *)
(*                                                                        *)
(**************************************************************************)

(** [Make (Data)] builds a trie (prefix tree) whose keys are sequences of
    [Data.elt]. A key is split into its elements with [Data.explode] before
    it is stored or looked up, and rebuilt with [Data.implode] when bindings
    are enumerated ([fold], [iter], [pp]).

    Contract on [Data.implode]: it receives the elements of a key in
    {e reverse} order (last element first), i.e. it is expected to satisfy
    [Data.implode (List.rev (Data.explode k)) = k].

    Sibling sub-tries are ordered with the polymorphic comparisons [=] and
    [<] on [Data.elt], so [Data.elt] must not contain functional values. *)
module Make (Data : sig
  type key
  type elt

  (** [explode k] is the sequence of elements of [k], first element first. *)
  val explode : key -> elt list

  (** [implode rev_elts] rebuilds a key from its elements given in reverse
      order (last element first). *)
  val implode : elt list -> key
end) =
struct
  (** A trie node: the value bound to the key spelled by the path from the
      root to this node (if any), and the child sub-tries indexed by the next
      key element. [add] keeps the children in increasing element order. *)
  type 'a t = ST of 'a option * (Data.elt * 'a t) list

  (* NOTE(review): this is [string] whatever [Data.key] is, while [add],
     [find], [fold], [iter] and [pp] actually take/produce [Data.key]. Left
     unchanged to keep the module interface stable; confirm whether
     [Data.key] was intended. *)
  type key = string

  (** The trie with no bindings. *)
  let empty = ST (None, [])

  (** [add ?overwrite id attr smtb] binds [id] to [attr] in [smtb].
      Returns [Some t'] with the updated trie, or [None] when [id] is already
      bound in [smtb] and [overwrite] is [false] (the default). With
      [~overwrite:true] an existing binding is replaced. *)
  let add ?(overwrite = false) id attr smtb =
    (* [insert1 elts node] inserts the remaining elements [elts] below
       [node]; a [None] ("already bound") result propagates upwards. *)
    let rec insert1 lts (ST (a, s)) =
      match lts with
      | [] -> (
          match (a, overwrite) with
          | None, _ | Some _, true -> Some (ST (Some attr, s))
          | Some _, false -> None)
      | l :: rm -> Option.map (fun o -> ST (a, o)) (insert2 l rm s)
    (* [insert2 elt elts children] descends into, or creates, the child for
       [elt] in the sorted association list [children]. *)
    and insert2 lt lts stls =
      match stls with
      | [] -> Option.map (fun o -> [ (lt, o) ]) (insert1 lts empty)
      | (l, i) :: rm ->
          if lt = l then Option.map (fun o -> (lt, o) :: rm) (insert1 lts i)
          else if lt < l then
            (* [lt] is smaller than every remaining child: insert it here so
               that [children] stays sorted. *)
            Option.map (fun o -> (lt, o) :: stls) (insert1 lts empty)
          else Option.map (fun o -> (l, i) :: o) (insert2 lt lts rm)
    in
    insert1 (Data.explode id) smtb

  (** [find w smtb] is [Some v] if [w] is bound to [v] in [smtb], and [None]
      otherwise. *)
  let find w smtb =
    let rec lookup1 lts (ST (a, s)) =
      match lts with [] -> a | l :: rm -> lookup2 l rm s
    and lookup2 lt lts stls =
      match stls with
      | [] -> None
      | (l, i) :: rm ->
          if lt = l then lookup1 lts i
          else if lt < l then
            (* Children are sorted: no later child can match [lt]. *)
            None
          else lookup2 lt lts rm
    in
    lookup1 (Data.explode w) smtb

  (** [fold f acc tr] folds [f key value] over every binding of [tr],
      starting from [acc]. Nodes are visited in pre-order: a key is visited
      before the keys it is a proper prefix of, and siblings are visited in
      list order (increasing element order for tries built with [add]). *)
  let fold f acc tr =
    (* [rev_key] holds the elements leading to the current node, last
       element first — the order [Data.implode] expects. *)
    let rec fold_aux rev_key acc (ST (v, children)) =
      let acc =
        match v with None -> acc | Some x -> f (Data.implode rev_key) x acc
      in
      List.fold_left
        (fun acc (c, t) -> fold_aux (c :: rev_key) acc t)
        acc children
    in
    fold_aux [] acc tr

  (** [iter f tr] applies [f key value] to every binding of [tr], in the same
      order as [fold]. *)
  let iter f tr =
    (* [rev_key]: elements leading to the current node, last one first. *)
    let rec iter_aux rev_key (ST (v, children)) =
      (match v with None -> () | Some x -> f (Data.implode rev_key) x);
      List.iter (fun (c, t) -> iter_aux (c :: rev_key) t) children
    in
    iter_aux [] tr

  (** [pp ?sep pp_binding fmt tr] prints every binding of [tr] on [fmt] with
      [pp_binding fmt key value], in the order of [iter], emitting [sep]
      (default ["@,"]) between consecutive bindings. *)
  let pp ?(sep = format_of_string "@,") pp_binding fmt tr =
    let pp_pair ppf (k, v) = pp_binding ppf k v in
    let first = ref true in
    iter
      (fun k v ->
        if !first then (
          first := false;
          pp_binding fmt k v)
        else Format.fprintf fmt (sep ^^ "%a") pp_pair (k, v))
      tr
end

(** Tries indexed by strings, one trie level per character. *)
module Tries =
  Make (struct
    type key = string
    type elt = char

    (* [explode str] lists the characters of [str], first character first. *)
    let explode str = List.init (String.length str) (String.get str)

    (* [implode rev_chars] builds the string made of [rev_chars] in reverse
       order: [Make] hands keys to [implode] last character first, so
       [implode (List.rev (explode s)) = s]. [List.rev] + [List.iter] makes
       the reversal explicit and keeps this tail-recursive. *)
    let implode lst =
      let buff = Buffer.create (List.length lst) in
      List.iter (Buffer.add_char buff) (List.rev lst);
      Buffer.contents buff
  end)
