Skip to content
Snippets Groups Projects

Perf: Implement a new algorithm for pattern matching compilation to reduce code size

Merged Melwyn Saldanha requested to merge melwyn95@compile_pattern_matching_to_decision_tree into dev
Compare and Show latest version
3 files
+ 45
- 197
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -16,8 +16,10 @@ let label_to_var (Label.Label name) = Value_var.of_input_var ~loc:Location.gener
@@ -16,8 +16,10 @@ let label_to_var (Label.Label name) = Value_var.of_input_var ~loc:Location.gener
module LabelSet = Set.Make (Label)
module LabelSet = Set.Make (Label)
module LabelMap = Label.Map
module LabelMap = Label.Map
 
module Usages = Set.Make (Value_var)
type signature = LabelSet.t
type signature = LabelSet.t
 
type usages = Usages.t
(* NOTE: Maybe add type expression later if needed *)
(* NOTE: Maybe add type expression later if needed *)
type simple_pattern =
type simple_pattern =
@@ -390,34 +392,39 @@ let swap_column_in_matrix : int -> matrix -> matrix =
@@ -390,34 +392,39 @@ let swap_column_in_matrix : int -> matrix -> matrix =
let rec generate_match_record
let rec generate_match_record
: Value_var.t -> O.type_expression -> decision_tree -> decision_tree
: Value_var.t -> O.type_expression -> bool -> usages -> decision_tree
 
-> bool * decision_tree
=
=
fun constructor_var t dt ->
fun constructor_var t used usages dt ->
match C.get_t_record t with
match C.get_t_record t with
| Some row ->
| Some row ->
let fields = I.Row.to_alist row in
let fields = I.Row.to_alist row in
let case =
let used, case =
List.fold fields ~init:dt ~f:(fun dt (l, ty) ->
List.fold fields ~init:(used, dt) ~f:(fun (used, dt) (l, ty) ->
match C.get_t_record ty with
match C.get_t_record ty with
| Some _ ->
| Some _ ->
generate_match_record (join_vars [ constructor_var; var_of_label l ]) ty dt
let v = join_vars [ constructor_var; var_of_label l ] in
| None -> dt)
let used', dt = generate_match_record v ty used usages dt in
 
used || used', dt
 
| None -> used, dt)
in
in
let field_binders =
let used, field_binders =
Record.of_list
List.fold_map fields ~init:used ~f:(fun used (field_label, t) ->
@@ List.map fields ~f:(fun (field_label, t) ->
let var = join_vars [ constructor_var; var_of_label field_label ] in
let var = join_vars [ constructor_var; var_of_label field_label ] in
let binder = Binder.make var t in
let binder = Binder.make var t in
used || Usages.mem usages var, (field_label, binder))
field_label, binder)
in
in
 
let field_binders = Record.of_list field_binders in
(* TODO: check if the type expression is correct here *)
(* TODO: check if the type expression is correct here *)
let matchee = Binder.make constructor_var t in
let matchee = Binder.make constructor_var t in
SwitchRecord { matchee; field_binders; case; record_type = t }
used, SwitchRecord { matchee; field_binders; case; record_type = t }
| None -> dt
| None -> used, dt
let rec compile : (Value_var.t * O.type_expression) list -> matrix -> decision_tree =
let rec compile
fun occurances matrix ->
: usages -> (Value_var.t * O.type_expression) list -> matrix -> decision_tree
 
=
 
fun usages occurances matrix ->
(* Format.printf "++++++++++++++++++++++\n";
(* Format.printf "++++++++++++++++++++++\n";
List.iter occurances ~f:(fun (o, t) ->
List.iter occurances ~f:(fun (o, t) ->
Format.printf "(%a, %a) " Label.pp o O.PP.type_expression t);
Format.printf "(%a, %a) " Label.pp o O.PP.type_expression t);
@@ -446,17 +453,23 @@ let rec compile : (Value_var.t * O.type_expression) list -> matrix -> decision_t
@@ -446,17 +453,23 @@ let rec compile : (Value_var.t * O.type_expression) list -> matrix -> decision_t
~init:[]
~init:[]
~f:(fun ~key:c ~data:(a, ty, parent_ty) l ->
~f:(fun ~key:c ~data:(a, ty, parent_ty) l ->
let t =
let t =
compile (specialize_occurances c occurances ty) (specialize c a matrix)
compile
 
usages
 
(specialize_occurances c occurances ty)
 
(specialize c a matrix)
in
in
let constructor_var = join_vars [ o; var_of_label c ] in
let constructor_var = join_vars [ o; var_of_label c ] in
let t = generate_match_record constructor_var ty t in
let used, rt = generate_match_record constructor_var ty false usages t in
 
let t = if used then rt else t in
({ constructor = c; binder = constructor_var; subtree = t }, parent_ty)
({ constructor = c; binder = constructor_var; subtree = t }, parent_ty)
:: l)
:: l)
in
in
let default =
let default =
List.map (LabelSet.to_list missing_constructors) ~f:(fun c ->
List.map (LabelSet.to_list missing_constructors) ~f:(fun c ->
let binder = join_vars [ o; var_of_label c ] in
let binder = join_vars [ o; var_of_label c ] in
let subtree = compile (default_occurances occurances) (default matrix) in
let subtree =
 
compile usages (default_occurances occurances) (default matrix)
 
in
{ constructor = c; binder; subtree })
{ constructor = c; binder; subtree })
in
in
let variant_type = List.hd_exn parent_tys in
let variant_type = List.hd_exn parent_tys in
@@ -486,7 +499,7 @@ let rec compile : (Value_var.t * O.type_expression) list -> matrix -> decision_t
@@ -486,7 +499,7 @@ let rec compile : (Value_var.t * O.type_expression) list -> matrix -> decision_t
Format.printf "\n$$$$$$$$$$$$$$$$$\n"; *)
Format.printf "\n$$$$$$$$$$$$$$$$$\n"; *)
raise e
raise e
in
in
compile swapped_occurances swapped_matrix)
compile usages swapped_occurances swapped_matrix)
let fail ~loc ty () =
let fail ~loc ty () =
@@ -568,38 +581,18 @@ let test
@@ -568,38 +581,18 @@ let test
"%a\n"
"%a\n"
(Simple_utils.PP_helpers.list_sep_d Sexp.pp_hum)
(Simple_utils.PP_helpers.list_sep_d Sexp.pp_hum)
(List.map sp ~f:sexp_of_simple_pattern)) *)
(List.map sp ~f:sexp_of_simple_pattern)) *)
let matrix =
let usages, matrix =
List.map cases ~f:(fun { pattern; body } ->
List.fold_map cases ~init:Usages.empty ~f:(fun usages { pattern; body } ->
let vars_projs = get_vars_and_projections matchee pattern in
let vars_projs = get_vars_and_projections matchee pattern in
let body =
let body, usages =
List.fold
List.fold
vars_projs
vars_projs
~init:body
~init:(body, usages)
~f:(fun body { bound_var = to_subst; new_var } ->
~f:(fun (body, usages) { bound_var = to_subst; new_var } ->
(* Format.printf
( Substitution.substitute_var_in_body ~to_subst ~new_var body
"pattern = %a | to_subst = %a | new_var = %a | body = %a\n"
, Usages.add usages new_var ))
(I.Pattern.pp I.PP.type_expression)
pattern
Value_var.pp
to_subst
Value_var.pp
new_var
O.PP.expression
body; *)
Substitution.substitute_var_in_body ~to_subst ~new_var body)
in
in
(* Format.printf
usages, (to_simple_pattern matchee_type pattern, body))
"\n\
******************\n\
old body = %a\n\
******************\n\
\ new_body = %a\n\
******************\n"
O.PP.expression
body
O.PP.expression
body'; *)
to_simple_pattern matchee_type pattern, body)
in
in
(* print_matrix matrix; *)
(* print_matrix matrix; *)
let top_occurances =
let top_occurances =
@@ -609,8 +602,9 @@ let test
@@ -609,8 +602,9 @@ let test
"top_occurances = %a\n"
"top_occurances = %a\n"
(Simple_utils.PP_helpers.list_sep_d Label.pp)
(Simple_utils.PP_helpers.list_sep_d Label.pp)
top_occurances; *)
top_occurances; *)
let dt = compile top_occurances matrix in
let dt = compile usages top_occurances matrix in
let dt = generate_match_record matchee matchee_type dt in
let used, rdt = generate_match_record matchee matchee_type false usages dt in
 
let dt = if used then rdt else dt in
let match_expr = to_match_expression ~fail ~body_type dt in
let match_expr = to_match_expression ~fail ~body_type dt in
(* Format.printf "\n@@@@@@@@@@@@\n%a\n@@@@@@@@@@@@\n" O.PP.expression match_expr; *)
(* Format.printf "\n@@@@@@@@@@@@\n%a\n@@@@@@@@@@@@\n" O.PP.expression match_expr; *)
(* print_decision_tree dt; *)
(* print_decision_tree dt; *)
Loading