Skip to content
Snippets Groups Projects

Perf: Implement a new algorithm for pattern matching compilation to reduce code size

Merged Melwyn Saldanha requested to merge melwyn95@compile_pattern_matching_to_decision_tree into dev
Compare and Show latest version
1 file
+ 71
12
Compare changes
  • Side-by-side
  • Inline
@@ -13,6 +13,7 @@ module O = Ast_expanded
@@ -13,6 +13,7 @@ module O = Ast_expanded
module C = Ast_aggregated.Combinators
module C = Ast_aggregated.Combinators
let t_unit = I.Combinators.t_unit
let t_unit = I.Combinators.t_unit
 
let empty_label = Label.of_string ""
(* TODO: figure out how to deal with name bound in patterns *)
(* TODO: figure out how to deal with name bound in patterns *)
@@ -160,15 +161,15 @@ let rec to_simple_pattern
@@ -160,15 +161,15 @@ let rec to_simple_pattern
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
List.fold_right
List.fold_right
ps
ps
~init:[ nil ty ]
~init:[ nil (get_variant_nested_type nil_label ty) ]
~f:(fun hd tl ->
~f:(fun hd tl ->
let hd = to_simple_pattern hd_ty hd in
let hd = to_simple_pattern hd_ty hd in
[ cons hd tl ty ])
[ cons hd tl (get_variant_nested_type cons_label ty) ])
| P_list (Cons (h, t)) ->
| P_list (Cons (h, t)) ->
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
let hd_ty = Option.value_exn ~here:[%here] (C.get_t_list ty) in
let h = to_simple_pattern hd_ty h in
let h = to_simple_pattern hd_ty h in
let t = to_simple_pattern ty t in
let t = to_simple_pattern ty t in
[ cons h t ty ]
[ cons h t (get_variant_nested_type cons_label ty) ]
| P_variant (c, p) ->
| P_variant (c, p) ->
let p_ty = get_variant_nested_type c ty in
let p_ty = get_variant_nested_type c ty in
let signature = get_signature_of_sum_type ty in
let signature = get_signature_of_sum_type ty in
@@ -191,6 +192,37 @@ let rec to_simple_pattern
@@ -191,6 +192,37 @@ let rec to_simple_pattern
to_simple_pattern p_ty p)
to_simple_pattern p_ty p)
 
let get_occurances : I.type_expression -> Label.t list =
 
fun t ->
 
if C.is_t_record t
 
then (
 
let rec aux : I.type_expression -> Label.t list =
 
fun t ->
 
match C.get_t_record t with
 
| Some row ->
 
let fields = I.Row.to_alist row in
 
List.concat_map fields ~f:(fun (l, t) ->
 
let t = aux t in
 
if List.is_empty t then [ l ] else List.map ~f:(Label.join l) t)
 
| None -> []
 
in
 
aux t)
 
else [ Label.of_int 0 ]
 
 
 
let specialize_occurances : Label.t list -> I.type_expression -> Label.t list =
 
fun os ty ->
 
let o = List.hd_exn os in
 
(* Format.printf
 
"specialize_occurances | o = %a | ty = %a\n"
 
Label.pp
 
o
 
I.PP.type_expression
 
ty; *)
 
let o1_n = get_occurances ty in
 
List.map o1_n ~f:(fun oi -> Label.join o oi) @ os
 
 
let specialize : Label.t -> int -> matrix -> matrix =
let specialize : Label.t -> int -> matrix -> matrix =
fun c a matrix ->
fun c a matrix ->
List.filter_map matrix ~f:(fun (pattern, body) ->
List.filter_map matrix ~f:(fun (pattern, body) ->
@@ -218,7 +250,7 @@ type decision_tree =
@@ -218,7 +250,7 @@ type decision_tree =
| Leaf of action
| Leaf of action
| Fail
| Fail
| SwitchConstructor of l
| SwitchConstructor of l
| SwitchRecord of (Label.t * (Label.t list)) list * decision_tree
| SwitchRecord of (Label.t * Label.t list) list * decision_tree
| Swap of int * decision_tree
| Swap of int * decision_tree
and l = (Label.t * decision_tree) list [@@deriving sexp]
and l = (Label.t * decision_tree) list [@@deriving sexp]
@@ -311,7 +343,9 @@ let swap_column_in_matrix : int -> matrix -> matrix =
@@ -311,7 +343,9 @@ let swap_column_in_matrix : int -> matrix -> matrix =
failwith s
failwith s
let rec generate_match_record : Label.t list -> I.type_expression -> decision_tree -> decision_tree =
let rec generate_match_record
 
: Label.t -> I.type_expression -> decision_tree -> decision_tree
 
=
fun lbls t dt ->
fun lbls t dt ->
match C.get_t_record t with
match C.get_t_record t with
| Some row ->
| Some row ->
@@ -320,16 +354,20 @@ let rec generate_match_record : Label.t list -> I.type_expression -> decision_tr
@@ -320,16 +354,20 @@ let rec generate_match_record : Label.t list -> I.type_expression -> decision_tr
let dt =
let dt =
List.fold fields ~init:dt ~f:(fun dt (l, ty) ->
List.fold fields ~init:dt ~f:(fun dt (l, ty) ->
match C.get_t_record ty with
match C.get_t_record ty with
| Some _ -> generate_match_record (join_labels (lbls @ [l])) ty dt
| Some _ -> generate_match_record (Label.join lbls l) ty dt
| None -> dt)
| None -> dt)
in
in
let labels = List.map labels ~f:(fun l -> l, join_labels (lbls @ [l])) in
let labels = List.map labels ~f:(fun l -> l, join_labels (lbls :: [ l ])) in
SwitchRecord (labels, dt)
SwitchRecord (labels, dt)
| None -> dt
| None -> dt
let rec compile : Label.t list -> matrix -> decision_tree =
let rec compile : Label.t list -> matrix -> decision_tree =
fun occurances matrix ->
fun occurances matrix ->
 
(* Format.printf "++++++++++++++++++++++\n";
 
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
 
Format.printf "\n++++++++++++++++++++++\n"; *)
 
(* print_matrix matrix; *)
match matrix with
match matrix with
| [] -> Fail
| [] -> Fail
| (row, body) :: _ when List.is_empty row -> Leaf body
| (row, body) :: _ when List.is_empty row -> Leaf body
@@ -350,8 +388,10 @@ let rec compile : Label.t list -> matrix -> decision_tree =
@@ -350,8 +388,10 @@ let rec compile : Label.t list -> matrix -> decision_tree =
head_constructors_with_arity
head_constructors_with_arity
~init:[]
~init:[]
~f:(fun ~key:c ~data:(a, ty) l ->
~f:(fun ~key:c ~data:(a, ty) l ->
let t = compile occurances (specialize c a matrix) in
let t =
let t = generate_match_record [c] ty t in
compile (specialize_occurances occurances ty) (specialize c a matrix)
 
in
 
let t = generate_match_record (Label.join (List.hd_exn occurances) c) ty t in
(c, t) :: l)
(c, t) :: l)
in
in
let default =
let default =
@@ -363,7 +403,17 @@ let rec compile : Label.t list -> matrix -> decision_tree =
@@ -363,7 +403,17 @@ let rec compile : Label.t list -> matrix -> decision_tree =
else SwitchConstructor (head @ default))
else SwitchConstructor (head @ default))
else (
else (
let swapped_matrix = swap_column_in_matrix i matrix in
let swapped_matrix = swap_column_in_matrix i matrix in
Swap (i, compile occurances swapped_matrix))
let swapped_occurances =
 
try swap_row i occurances with
 
| e ->
 
(* TODO: clean this up *)
 
(* print_matrix matrix;
 
Format.printf "i = %d $$$$$$$$$$$$$$$$$\n" i;
 
List.iter occurances ~f:(fun o -> Format.printf "%a " Label.pp o);
 
Format.printf "\n$$$$$$$$$$$$$$$$$\n"; *)
 
raise e
 
in
 
Swap (i, compile swapped_occurances swapped_matrix))
let test
let test
@@ -382,8 +432,13 @@ let test
@@ -382,8 +432,13 @@ let test
to_simple_pattern matchee_type pattern, (vars_projs, body))
to_simple_pattern matchee_type pattern, (vars_projs, body))
in
in
(* print_matrix matrix; *)
(* print_matrix matrix; *)
let dt = compile [] matrix in
let top_occurances = get_occurances matchee_type in
let dt = generate_match_record [] matchee_type dt in
(* Format.printf
 
"top_occurances = %a\n"
 
(Simple_utils.PP_helpers.list_sep_d Label.pp)
 
top_occurances; *)
 
let dt = compile top_occurances matrix in
 
let dt = generate_match_record empty_label matchee_type dt in
(* print_decision_tree dt; *)
(* print_decision_tree dt; *)
ignore dt
ignore dt
@@ -394,4 +449,8 @@ if 1st occurance is a record/tuple generate SwitchRecord
@@ -394,4 +449,8 @@ if 1st occurance is a record/tuple generate SwitchRecord
Idea: generate all Match records then make another pass to remove unused nodes
Idea: generate all Match records then make another pass to remove unused nodes
- at top-level if record pattern destruct it
- at top-level if record pattern destruct it
 
 
- occurance to match on for switch-constructor
 
- occurance to match on for switch-record
 
- Constructor variable
*)
*)
Loading