Skip to content
Snippets Groups Projects

Perf: Implement a new algorithm for pattern matching compilation to reduce code size

Merged Melwyn Saldanha requested to merge melwyn95@compile_pattern_matching_to_decision_tree into dev
1 file
+ 71
12
Compare changes
  • Side-by-side
  • Inline
@@ -13,6 +13,7 @@ module O = Ast_expanded
module C = Ast_aggregated.Combinators
let t_unit = I.Combinators.t_unit
let empty_label = Label.of_string ""
(* TODO: figure out how to deal with name bound in patterns *)
@@ -160,15 +161,15 @@ let rec to_simple_pattern
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
List.fold_right
ps
~init:[ nil ty ]
~init:[ nil (get_variant_nested_type nil_label ty) ]
~f:(fun hd tl ->
let hd = to_simple_pattern hd_ty hd in
[ cons hd tl ty ])
[ cons hd tl (get_variant_nested_type cons_label ty) ])
| P_list (Cons (h, t)) ->
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
let h = to_simple_pattern hd_ty h in
let t = to_simple_pattern ty t in
[ cons h t ty ]
[ cons h t (get_variant_nested_type cons_label ty) ]
| P_variant (c, p) ->
let p_ty = get_variant_nested_type c ty in
let signature = get_signature_of_sum_type ty in
@@ -191,6 +192,37 @@ let rec to_simple_pattern
to_simple_pattern p_ty p)
let get_occurances : I.type_expression -> Label.t list =
fun t ->
if C.is_t_record t
then (
let rec aux : I.type_expression -> Label.t list =
fun t ->
match C.get_t_record t with
| Some row ->
let fields = I.Row.to_alist row in
List.concat_map fields ~f:(fun (l, t) ->
let t = aux t in
if List.is_empty t then [ l ] else List.map ~f:(Label.join l) t)
| None -> []
in
aux t)
else [ Label.of_int 0 ]
let specialize_occurances : Label.t list -> I.type_expression -> Label.t list =
fun os ty ->
let o = List.hd_exn os in
(* Format.printf
"specialize_occurances | o = %a | ty = %a\n"
Label.pp
o
I.PP.type_expression
ty; *)
let o1_n = get_occurances ty in
List.map o1_n ~f:(fun oi -> Label.join o oi) @ os
let specialize : Label.t -> int -> matrix -> matrix =
fun c a matrix ->
List.filter_map matrix ~f:(fun (pattern, body) ->
@@ -218,7 +250,7 @@ type decision_tree =
| Leaf of action
| Fail
| SwitchConstructor of l
| SwitchRecord of (Label.t * (Label.t list)) list * decision_tree
| SwitchRecord of (Label.t * Label.t list) list * decision_tree
| Swap of int * decision_tree
and l = (Label.t * decision_tree) list [@@deriving sexp]
@@ -311,7 +343,9 @@ let swap_column_in_matrix : int -> matrix -> matrix =
failwith s
let rec generate_match_record : Label.t list -> I.type_expression -> decision_tree -> decision_tree =
let rec generate_match_record
: Label.t -> I.type_expression -> decision_tree -> decision_tree
=
fun lbls t dt ->
match C.get_t_record t with
| Some row ->
@@ -320,16 +354,20 @@ let rec generate_match_record : Label.t list -> I.type_expression -> decision_tr
let dt =
List.fold fields ~init:dt ~f:(fun dt (l, ty) ->
match C.get_t_record ty with
| Some _ -> generate_match_record (join_labels (lbls @ [l])) ty dt
| Some _ -> generate_match_record (Label.join lbls l) ty dt
| None -> dt)
in
let labels = List.map labels ~f:(fun l -> l, join_labels (lbls @ [l])) in
let labels = List.map labels ~f:(fun l -> l, join_labels (lbls :: [ l ])) in
SwitchRecord (labels, dt)
| None -> dt
let rec compile : Label.t list -> matrix -> decision_tree =
fun occurances matrix ->
(* Format.printf "++++++++++++++++++++++\n";
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
Format.printf "\n++++++++++++++++++++++\n"; *)
(* print_matrix matrix; *)
match matrix with
| [] -> Fail
| (row, body) :: _ when List.is_empty row -> Leaf body
@@ -350,8 +388,10 @@ let rec compile : Label.t list -> matrix -> decision_tree =
head_constructors_with_arity
~init:[]
~f:(fun ~key:c ~data:(a, ty) l ->
let t = compile occurances (specialize c a matrix) in
let t = generate_match_record [c] ty t in
let t =
compile (specialize_occurances occurances ty) (specialize c a matrix)
in
let t = generate_match_record (Label.join (List.hd_exn occurances) c) ty t in
(c, t) :: l)
in
let default =
@@ -363,7 +403,17 @@ let rec compile : Label.t list -> matrix -> decision_tree =
else SwitchConstructor (head @ default))
else (
let swapped_matrix = swap_column_in_matrix i matrix in
Swap (i, compile occurances swapped_matrix))
let swapped_occurances =
try swap_row i occurances with
| e ->
(* TODO: clean this up *)
(* print_matrix matrix;
Format.printf "i = %d $$$$$$$$$$$$$$$$$\n" i;
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
Format.printf "\n$$$$$$$$$$$$$$$$$\n"; *)
raise e
in
Swap (i, compile swapped_occurances swapped_matrix))
let test
@@ -382,8 +432,13 @@ let test
to_simple_pattern matchee_type pattern, (vars_projs, body))
in
(* print_matrix matrix; *)
let dt = compile [] matrix in
let dt = generate_match_record [] matchee_type dt in
let top_occurances = get_occurances matchee_type in
(* Format.printf
"top_occurances = %a\n"
(Simple_utils.PP_helpers.list_sep_d Label.pp)
top_occurances; *)
let dt = compile top_occurances matrix in
let dt = generate_match_record empty_label matchee_type dt in
(* print_decision_tree dt; *)
ignore dt
@@ -394,4 +449,8 @@ if 1st occurance is a record/tuple generate SwitchRecord
Idea: generate all Match records then make another pass to remove unused nodes
- at top-level if record pattern destruct it
- occurance to match on for switch-constructor
- occurance to match on for switch-record
- Constructor variable
*)
Loading