Skip to content
Snippets Groups Projects

Perf: Implement a new algorithm for pattern matching compilation to reduce code size

Merged Melwyn Saldanha requested to merge melwyn95@compile_pattern_matching_to_decision_tree into dev
Compare and Show latest version
2 files
+ 209
44
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -13,6 +13,7 @@ module O = Ast_expanded
module C = Ast_aggregated.Combinators
let t_unit = I.Combinators.t_unit
let empty_label = Label.of_string ""
(* TODO: figure out how to deal with names bound in patterns *)
@@ -25,11 +26,18 @@ type signature = LabelSet.t
(* NOTE: Maybe add type expression later if needed *)
type simple_pattern =
| SP_Var of Value_var.t
| SP_Constructor of Label.t * simple_pattern list * (signature[@sexp.opaque])
| SP_Var of Value_var.t * (I.type_expression[@sexp.opaque])
| SP_Constructor of
Label.t
* simple_pattern list
* (signature[@sexp.opaque])
* (I.type_expression[@sexp.opaque])
[@@deriving sexp]
(* TODO: make this into a record *)
type action = (Value_var.t * Label.t list) list * (I.expression[@sexp.opaque])
[@@deriving sexp]
type action = Value_var.t list * (I.expression[@sexp.opaque]) [@@deriving sexp]
type row = simple_pattern list * action [@@deriving sexp]
type matrix = row list [@@deriving sexp]
@@ -38,6 +46,13 @@ let print_matrix : matrix -> unit =
Format.printf "----------\n%a\n----------\n" Sexp.pp_hum (sexp_of_matrix matrix)
(* [get_number_of_fields ty] is the length of the row returned by
   [C.get_t_record ty] when there is one, and [1] otherwise. *)
let get_number_of_fields : I.type_expression -> int =
 fun ty ->
  let record_row = C.get_t_record ty in
  match record_row with
  | None -> 1
  | Some fields -> I.Row.length fields
(* let arity : simple_pattern -> int =
fun sp ->
match sp with
@@ -47,18 +62,72 @@ let print_matrix : matrix -> unit =
let rec n_vars n =
if n = 0
then []
else SP_Var (Value_var.fresh ~loc:Location.generated ()) :: n_vars (n - 1)
else
SP_Var (Value_var.fresh ~loc:Location.generated (), t_unit ~loc:Location.generated ())
:: n_vars (n - 1)
(* Constructor labels under which list patterns are encoded as simple
   patterns: "Nil" and "Cons". *)
let nil_label = Label.of_string "Nil"
let cons_label = Label.of_string "Cons"
(* The set containing exactly the two list constructor labels above. *)
let list_signature = LabelSet.of_list [ nil_label; cons_label ]
let nil = SP_Constructor (nil_label, [], list_signature)
let cons hd tl = SP_Constructor (cons_label, hd @ tl, list_signature)
(* [nil ty] is the "Nil" constructor pattern: no sub-patterns, the list
   signature, and [ty] stored as the constructor's type component. *)
let nil ty = SP_Constructor (nil_label, [], list_signature, ty)
(* [cons hd tl ty] is the "Cons" constructor pattern whose sub-patterns are the
   simple patterns [hd] followed by the simple patterns [tl] (both are lists
   and are concatenated); [ty] is stored as the constructor's type component. *)
let cons hd tl ty = SP_Constructor (cons_label, hd @ tl, list_signature, ty)
(* [get_variant_nested_type label tsum] looks [label] up in the [fields] map
   of the row [tsum], using [LabelMap.find_exn]. *)
let get_variant_nested_type label (tsum : I.row) = LabelMap.find_exn tsum.fields label
(* [join_labels lbls] collapses [lbls] into a single label by folding
   [Label.join] from the right, starting from the empty-string label, and
   returns that label wrapped in a one-element list. *)
let join_labels lbls =
  let joined =
    List.fold_right lbls ~init:(Label.of_string "") ~f:(fun lbl rest ->
      Label.join lbl rest)
  in
  [ joined ]
(* [get_vars_and_projections path p] collects every variable bound by the
   pattern [p], paired with the label path of the sub-value it binds.

   [path] is the label path of the value matched by [p]; it is extended with a
   constructor label, tuple index or record field label while descending into
   sub-patterns. The path attached to each variable is collapsed with
   [join_labels].

   List patterns are addressed as nested "Cons" cells: the head of the cell at
   [path] is [path @ [Cons; 0]] and its tail is [path @ [Cons; 1]]. *)
let rec get_vars_and_projections
  : Label.t list -> I.type_expression I.Pattern.t -> (Value_var.t * Label.t list) list
  =
 fun path p ->
  let open I.Pattern in
  (* Paths of the two components of the "Cons" cell located at [cell]. *)
  let head_of cell = cell @ [ cons_label; Label.of_int 0 ] in
  let tail_of cell = cell @ [ cons_label; Label.of_int 1 ] in
  match Location.unwrap p with
  | P_unit -> []
  | P_var b -> [ Binder.get_var b, join_labels path ]
  | P_list (List ps) ->
    (* A literal list [p1; ...; pn] stands for Cons (p1, ... Cons (pn, Nil)):
       each element is the head of its own cell, so the path must step into
       the tail before the next element is visited. Otherwise distinct
       elements would be given the same path. *)
    let rec elements cell = function
      | [] -> []
      | elt :: rest ->
        get_vars_and_projections (head_of cell) elt @ elements (tail_of cell) rest
    in
    elements path ps
  | P_list (Cons (h, t)) ->
    get_vars_and_projections (head_of path) h
    @ get_vars_and_projections (tail_of path) t
  | P_variant (c, p) -> get_vars_and_projections (path @ [ c ]) p
  | P_tuple ps ->
    (* Tuple components are addressed by their position. *)
    List.concat_mapi ps ~f:(fun i p ->
      get_vars_and_projections (path @ [ Label.of_int i ]) p)
  | P_record ps ->
    (* Record components are addressed by their field label. *)
    List.concat_map (Record.to_list ps) ~f:(fun (l, p) ->
      get_vars_and_projections (path @ [ l ]) p)
(* [get_variant_nested_type label ty] is the type carried by constructor
   [label] of [ty].

   - When [C.get_t_sum ty] yields a row, the result is the entry for [label]
     in that row's [fields] map (looked up with [LabelMap.find_exn]).
   - Otherwise [ty] is asserted to be a list type: for the "Cons" label the
     result is the tuple record of the element type and [ty] itself; for any
     other label it is the unit type. *)
let get_variant_nested_type label ty =
  match C.get_t_sum ty with
  | Some sum_row -> LabelMap.find_exn sum_row.fields label
  | None ->
    assert (C.is_t_list ty);
    let loc = Location.generated in
    if not (Label.equal label cons_label)
    then C.t_unit ~loc ()
    else (
      let elt_ty = Option.value_exn ~here:[%here] @@ C.get_t_list ty in
      C.t_record ~loc (I.Row.create_tuple [ elt_ty; ty ]) ())
let get_signature_of_sum_type : I.type_expression -> signature =
@@ -69,13 +138,6 @@ let get_signature_of_sum_type : I.type_expression -> signature =
LabelSet.of_list cs
(* [get_number_of_fields ty] is the length of the row returned by
   [C.get_t_record ty] when there is one, and [1] otherwise. *)
let get_number_of_fields : I.type_expression -> int =
 fun ty ->
  let record_row = C.get_t_record ty in
  match record_row with
  | None -> 1
  | Some fields -> I.Row.length fields
let rec to_simple_pattern
: I.type_expression -> I.type_expression I.Pattern.t -> simple_pattern list
=
@@ -91,26 +153,28 @@ let rec to_simple_pattern
match Location.unwrap p with
| P_unit ->
let v = Value_var.fresh ~loc ~name:"unit_pattern" () in
[ SP_Var v ]
assert (C.is_t_unit ty);
[ SP_Var (v, ty) ]
| P_var _ -> n_vars (get_number_of_fields ty)
| P_list (List []) -> [ nil ]
| P_list (List []) -> [ nil ty ]
| P_list (List ps) ->
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
List.fold_right ps ~init:[ nil ] ~f:(fun hd tl ->
List.fold_right
ps
~init:[ nil (get_variant_nested_type nil_label ty) ]
~f:(fun hd tl ->
let hd = to_simple_pattern hd_ty hd in
[ cons hd tl ])
[ cons hd tl (get_variant_nested_type cons_label ty) ])
| P_list (Cons (h, t)) ->
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
let h = to_simple_pattern hd_ty h in
let t = to_simple_pattern ty t in
[ cons h t ]
[ cons h t (get_variant_nested_type cons_label ty) ]
| P_variant (c, p) ->
let p_ty =
get_variant_nested_type c (Option.value_exn ~here:[%here] (C.get_t_sum ty))
in
let p_ty = get_variant_nested_type c ty in
let signature = get_signature_of_sum_type ty in
let ps = if is_unit_pattern p then [] else to_simple_pattern p_ty p in
[ SP_Constructor (c, ps, signature) ]
[ SP_Constructor (c, ps, signature, p_ty) ]
| P_tuple ps ->
let row = Option.value_exn ~here:[%here] (C.get_t_record ty) in
List.concat_mapi ps ~f:(fun i p ->
@@ -128,14 +192,49 @@ let rec to_simple_pattern
to_simple_pattern p_ty p)
(* [get_occurances ty] lists the labels addressing the leaves of [ty].

   When [C.is_t_record ty] holds, the record is walked recursively: a field
   whose type yields no nested labels contributes its own label, and any other
   field contributes each nested label joined onto the field's label with
   [Label.join]. When [ty] is not a record the result is the single label
   [Label.of_int 0]. *)
let get_occurances : I.type_expression -> Label.t list =
 fun ty ->
  let rec leaf_labels : I.type_expression -> Label.t list =
   fun ty ->
    match C.get_t_record ty with
    | None -> []
    | Some row ->
      I.Row.to_alist row
      |> List.concat_map ~f:(fun (field, field_ty) ->
           match leaf_labels field_ty with
           | [] -> [ field ]
           | nested -> List.map nested ~f:(fun sub -> Label.join field sub))
  in
  if C.is_t_record ty then leaf_labels ty else [ Label.of_int 0 ]
(* [specialize_occurances os ty] puts in front of [os] the occurrences of [ty]
   (as computed by [get_occurances]), each joined onto the first element of
   [os] with [Label.join]. [List.hd_exn] is applied to [os], so [os] is
   expected to be non-empty.

   NOTE(review): the first element of [os] is kept in the result rather than
   being replaced by its sub-occurrences - confirm this is intended. *)
let specialize_occurances : Label.t list -> I.type_expression -> Label.t list =
 fun os ty ->
  let parent = List.hd_exn os in
  let children =
    get_occurances ty |> List.map ~f:(fun child -> Label.join parent child)
  in
  children @ os
let specialize : Label.t -> int -> matrix -> matrix =
fun c a matrix ->
List.filter_map matrix ~f:(fun (pattern, body) ->
match pattern with
| [] -> Some ([], body)
| SP_Constructor (c', qs, _) :: ps when Label.equal c c' -> Some (qs @ ps, body)
| SP_Constructor (c', qs, _, _ty) :: ps when Label.equal c c' ->
(* Format.printf "a = %d | c = %a | ty = %a\n" a Label.pp c' I.PP.type_expression ty; *)
Some (qs @ ps, body)
| SP_Constructor _ :: _ -> None
| SP_Var _ :: ps -> Some (n_vars a @ ps, body))
| SP_Var _ :: ps ->
(* Format.printf "a = %d | c = %a | ty = %a\n" a Label.pp c I.PP.type_expression ty; *)
Some (n_vars a @ ps, body))
let default : matrix -> matrix =
@@ -150,7 +249,8 @@ let default : matrix -> matrix =
type decision_tree =
| Leaf of action
| Fail
| Switch of l
| SwitchConstructor of l
| SwitchRecord of (Label.t * Label.t list) list * decision_tree
| Swap of int * decision_tree
and l = (Label.t * decision_tree) list [@@deriving sexp]
@@ -170,12 +270,15 @@ let has_all_var_pattersn : simple_pattern list -> bool =
| SP_Constructor _ -> false)
let get_first_column_constructor : row -> ((Label.t * int) * signature) option =
let get_first_column_constructor
: row -> ((Label.t * (int * I.type_expression)) * signature) option
=
fun (pattern, _) ->
match pattern with
| [] -> None
| SP_Var _ :: _ -> None
| SP_Constructor (c, ps, signature) :: _ -> Some ((c, List.length ps), signature)
| SP_Constructor (c, ps, signature, ty) :: _ ->
Some ((c, (List.length ps, ty)), signature)
let remove_first_column : row -> row =
@@ -186,16 +289,18 @@ let remove_first_column : row -> row =
let head_constructors_of_column_with_atleast_one_constructor
: matrix -> int * int LabelMap.t * signature
: matrix -> int * (int * I.type_expression) LabelMap.t * signature
=
fun matrix ->
let rec aux : matrix -> int -> int * int LabelMap.t * signature =
let rec aux : matrix -> int -> int * (int * I.type_expression) LabelMap.t * signature =
fun rows idx ->
let first_column_constructors, signatures =
List.unzip @@ List.filter_map rows ~f:get_first_column_constructor
in
let first_column_constructors =
List.dedup_and_sort first_column_constructors ~compare:(fun (l1, a1) (l2, a2) ->
List.dedup_and_sort
first_column_constructors
~compare:(fun (l1, (a1, _)) (l2, (a2, _)) ->
match Label.compare l1 l2 with
| 0 -> Int.compare a1 a2
| c -> c)
@@ -238,8 +343,31 @@ let swap_column_in_matrix : int -> matrix -> matrix =
failwith s
let rec compile : matrix -> decision_tree =
fun matrix ->
(* [generate_match_record prefix ty subtree] wraps [subtree] in [SwitchRecord]
   nodes when [ty] is a record type, and returns [subtree] unchanged otherwise.

   For a record, the fields are folded over in the order given by
   [I.Row.to_alist]; each field whose own type is a record wraps the
   accumulated tree in a further [SwitchRecord] (recursively, under the label
   [Label.join prefix field]). The resulting tree is then placed under a
   [SwitchRecord] that pairs every field label with
   [join_labels [ prefix; field ]]. *)
let rec generate_match_record
  : Label.t -> I.type_expression -> decision_tree -> decision_tree
  =
 fun prefix ty subtree ->
  match C.get_t_record ty with
  | None -> subtree
  | Some row ->
    let fields = I.Row.to_alist row in
    let wrap_nested acc (field, field_ty) =
      if Option.is_some (C.get_t_record field_ty)
      then generate_match_record (Label.join prefix field) field_ty acc
      else acc
    in
    let inner = List.fold fields ~init:subtree ~f:wrap_nested in
    let bindings =
      List.map fields ~f:(fun (field, _) -> field, join_labels [ prefix; field ])
    in
    SwitchRecord (bindings, inner)
let rec compile : Label.t list -> matrix -> decision_tree =
fun occurances matrix ->
(* Format.printf "++++++++++++++++++++++\n";
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
Format.printf "\n++++++++++++++++++++++\n"; *)
(* print_matrix matrix; *)
match matrix with
| [] -> Fail
| (row, body) :: _ when List.is_empty row -> Leaf body
@@ -256,20 +384,36 @@ let rec compile : matrix -> decision_tree =
in
let missing_constructors = LabelSet.diff signature head_constructors in
let head =
LabelMap.fold head_constructors_with_arity ~init:[] ~f:(fun ~key:c ~data:a l ->
let t = compile (specialize c a matrix) in
LabelMap.fold
head_constructors_with_arity
~init:[]
~f:(fun ~key:c ~data:(a, ty) l ->
let t =
compile (specialize_occurances occurances ty) (specialize c a matrix)
in
let t = generate_match_record (Label.join (List.hd_exn occurances) c) ty t in
(c, t) :: l)
in
let default =
List.map (LabelSet.to_list missing_constructors) ~f:(fun c ->
c, compile (default matrix))
c, compile occurances (default matrix))
in
if LabelSet.is_empty missing_constructors
then Switch head
else Switch (head @ default))
then SwitchConstructor head
else SwitchConstructor (head @ default))
else (
let swapped_matrix = swap_column_in_matrix i matrix in
Swap (i, compile swapped_matrix))
let swapped_occurances =
try swap_row i occurances with
| e ->
(* TODO: clean this up *)
(* print_matrix matrix;
Format.printf "i = %d $$$$$$$$$$$$$$$$$\n" i;
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
Format.printf "\n$$$$$$$$$$$$$$$$$\n"; *)
raise e
in
Swap (i, compile swapped_occurances swapped_matrix))
let test
@@ -284,10 +428,29 @@ let test
(List.map sp ~f:sexp_of_simple_pattern)) *)
let matrix =
List.map cases ~f:(fun { pattern; body } ->
let vars = List.map ~f:Binder.get_var @@ I.Pattern.binders pattern in
to_simple_pattern matchee_type pattern, (vars, body))
let vars_projs = get_vars_and_projections [] pattern in
to_simple_pattern matchee_type pattern, (vars_projs, body))
in
(* print_matrix matrix; *)
let dt = compile matrix in
let top_occurances = get_occurances matchee_type in
(* Format.printf
"top_occurances = %a\n"
(Simple_utils.PP_helpers.list_sep_d Label.pp)
top_occurances; *)
let dt = compile top_occurances matrix in
let dt = generate_match_record empty_label matchee_type dt in
(* print_decision_tree dt; *)
ignore dt
(*
If the 1st occurrence is a record/tuple, generate a SwitchRecord
and repeat until no unmatched records remain.
Idea: generate all record matches, then make another pass to remove unused nodes
- at top level, if the pattern is a record pattern, destructure it
- occurrence to match on for switch-constructor
- occurrence to match on for switch-record
- constructor variable
*)
Loading