Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind methods in pure #412

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 73 additions & 183 deletions src/extract/Extract.ml

Large diffs are not rendered by default.

125 changes: 80 additions & 45 deletions src/extract/ExtractBase.ml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,18 @@ let decl_is_not_last_from_group (kind : decl_kind) : bool =

type type_decl_kind = Enum | Struct | Tuple [@@deriving show]

(** Generics can be bound in two places: each item has its generics, and
additionally within a trait decl or impl each method has its own generics.
We distinguish these two cases here. In charon, the distinction is made
thanks to `de_bruijn_var`.
Note that for the generics of a top-level `fun_decl` we always use `Item`;
`Method` only refers to the inner binder found in the list of methods in a
trait_decl/trait_impl.
*)
type generic_origin = Item | Method

(** We use identifiers to look for name clashes *)
type id =
and id =
| GlobalId of A.GlobalDeclId.id
| FunId of fun_id
| TerminationMeasureId of (A.fun_id * LoopId.id option)
Expand Down Expand Up @@ -162,12 +172,12 @@ type id =
must be unique (it is the case in F* ) which is why we register
them here.
*)
| TypeVarId of TypeVarId.id
| ConstGenericVarId of ConstGenericVarId.id
| VarId of VarId.id
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
| TypeVarId of generic_origin * TypeVarId.id
| ConstGenericVarId of generic_origin * ConstGenericVarId.id
| LocalTraitClauseId of generic_origin * TraitClauseId.id
| TraitDeclConstructorId of TraitDeclId.id
| TraitMethodId of TraitDeclId.id * string
| TraitItemId of TraitDeclId.id * string
Expand Down Expand Up @@ -674,14 +684,19 @@ let id_to_string (span : Meta.span option) (id : id) (ctx : extraction_ctx) :
let field_name = adt_field_to_string span ctx id field_id in
"type name: " ^ type_name ^ ", field name: " ^ field_name
| UnknownId -> "keyword"
| TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id
| ConstGenericVarId id ->
"const_generic_var_id: " ^ ConstGenericVarId.to_string id
| VarId id -> "var_id: " ^ VarId.to_string id
| TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
| TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
| LocalTraitClauseId id ->
"local_trait_clause_id: " ^ TraitClauseId.to_string id
| TypeVarId (origin, id) ->
"type_var_id: " ^ TypeVarId.to_string id ^ " from "
^ show_generic_origin origin
| ConstGenericVarId (origin, id) ->
"const_generic_var_id: "
^ ConstGenericVarId.to_string id
^ " from " ^ show_generic_origin origin
| LocalTraitClauseId (origin, id) ->
"local_trait_clause_id: " ^ TraitClauseId.to_string id ^ " from "
^ show_generic_origin origin
| TraitDeclConstructorId id ->
"trait_decl_constructor: " ^ trait_decl_id_to_string id
| TraitParentClauseId (id, clause_id) ->
Expand Down Expand Up @@ -771,17 +786,32 @@ let ctx_get_var (span : Meta.span) (id : VarId.id) (ctx : extraction_ctx) :
string =
ctx_get (Some span) (VarId id) ctx

let ctx_get_type_var (span : Meta.span) (id : TypeVarId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (TypeVarId id) ctx
(** This warrants explanations. Charon supports several levels of nested
binders; however there are currently only two cases where we bind
non-lifetime variables: at the top-level of each item, and for each method
inside a trait_decl/trait_impl. Moreover, we use `Free` vars to identify
item-bound vars. This means that we can identify which binder a variable
comes from without rigorously tracking binder levels, which is what this
function does.
Note that the `de_bruijn_id`s are wrong anyway because we kept charon's
binding levels but forgot all the region binders.
*)
let origin_from_de_bruijn_var (var : 'a de_bruijn_var) : generic_origin * 'a =
match var with
| Bound (_, id) -> (Method, id)
| Free id -> (Item, id)

let ctx_get_const_generic_var (span : Meta.span) (id : ConstGenericVarId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (ConstGenericVarId id) ctx
let ctx_get_type_var (span : Meta.span) (origin : generic_origin)
(id : TypeVarId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (TypeVarId (origin, id)) ctx

let ctx_get_local_trait_clause (span : Meta.span) (id : TraitClauseId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (LocalTraitClauseId id) ctx
let ctx_get_const_generic_var (span : Meta.span) (origin : generic_origin)
(id : ConstGenericVarId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (ConstGenericVarId (origin, id)) ctx

let ctx_get_local_trait_clause (span : Meta.span) (origin : generic_origin)
(id : TraitClauseId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (LocalTraitClauseId (origin, id)) ctx

let ctx_get_field (span : Meta.span) (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
Expand Down Expand Up @@ -1956,27 +1986,29 @@ let basename_to_unique (ctx : extraction_ctx) (name : string) =
basename_to_unique_aux collision name_append_index name

(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (span : Meta.span) (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
let ctx_add_type_var (span : Meta.span) (origin : generic_origin)
(basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = ctx_compute_type_var_basename ctx basename in
let name = basename_to_unique ctx name in
let ctx = ctx_add span (TypeVarId id) name ctx in
let ctx = ctx_add span (TypeVarId (origin, id)) name ctx in
(ctx, name)

(** Generate a unique const generic variable name and add it to the context *)
let ctx_add_const_generic_var (span : Meta.span) (basename : string)
(id : ConstGenericVarId.id) (ctx : extraction_ctx) : extraction_ctx * string
=
let ctx_add_const_generic_var (span : Meta.span) (origin : generic_origin)
(basename : string) (id : ConstGenericVarId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = ctx_compute_const_generic_var_basename ctx basename in
let name = basename_to_unique ctx name in
let ctx = ctx_add span (ConstGenericVarId id) name ctx in
let ctx = ctx_add span (ConstGenericVarId (origin, id)) name ctx in
(ctx, name)

(** See {!ctx_add_type_var} *)
let ctx_add_type_vars (span : Meta.span) (vars : (string * TypeVarId.id) list)
(ctx : extraction_ctx) : extraction_ctx * string list =
let ctx_add_type_vars (span : Meta.span) (origin : generic_origin)
(vars : (string * TypeVarId.id) list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
(fun ctx (name, id) -> ctx_add_type_var span name id ctx)
(fun ctx (name, id) -> ctx_add_type_var span origin name id ctx)
ctx vars

(** Generate a unique variable name and add it to the context *)
Expand All @@ -1995,10 +2027,11 @@ let ctx_add_trait_self_clause (span : Meta.span) (ctx : extraction_ctx) :
(ctx, name)

(** Generate a unique trait clause name and add it to the context *)
let ctx_add_local_trait_clause (span : Meta.span) (basename : string)
(id : TraitClauseId.id) (ctx : extraction_ctx) : extraction_ctx * string =
let ctx_add_local_trait_clause (span : Meta.span) (origin : generic_origin)
(basename : string) (id : TraitClauseId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = basename_to_unique ctx basename in
let ctx = ctx_add span (LocalTraitClauseId id) name ctx in
let ctx = ctx_add span (LocalTraitClauseId (origin, id)) name ctx in
(ctx, name)

(** See {!ctx_add_var} *)
Expand All @@ -2010,18 +2043,20 @@ let ctx_add_vars (span : Meta.span) (vars : var list) (ctx : extraction_ctx) :
ctx_add_var span name v.id ctx)
ctx vars

let ctx_add_type_params (span : Meta.span) (vars : type_var list)
(ctx : extraction_ctx) : extraction_ctx * string list =
let ctx_add_type_params (span : Meta.span) (origin : generic_origin)
(vars : type_var list) (ctx : extraction_ctx) : extraction_ctx * string list
=
List.fold_left_map
(fun ctx (var : type_var) -> ctx_add_type_var span var.name var.index ctx)
(fun ctx (var : type_var) ->
ctx_add_type_var span origin var.name var.index ctx)
ctx vars

let ctx_add_const_generic_params (span : Meta.span)
let ctx_add_const_generic_params (span : Meta.span) (origin : generic_origin)
(vars : const_generic_var list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
(fun ctx (var : const_generic_var) ->
ctx_add_const_generic_var span var.name var.index ctx)
ctx_add_const_generic_var span origin var.name var.index ctx)
ctx vars

(** Returns the lists of names for:
Expand All @@ -2034,16 +2069,16 @@ let ctx_add_const_generic_params (span : Meta.span)
for additional information.
*)
let ctx_add_local_trait_clauses (span : Meta.span)
(current_def_name : Types.name) (llbc_generics : Types.generic_params)
(clauses : trait_clause list) (ctx : extraction_ctx) :
extraction_ctx * string list =
(current_def_name : Types.name) (origin : generic_origin)
(llbc_generics : Types.generic_params) (clauses : trait_clause list)
(ctx : extraction_ctx) : extraction_ctx * string list =
List.fold_left_map
(fun ctx (c : trait_clause) ->
let basename =
ctx_compute_trait_clause_basename ctx current_def_name llbc_generics
c.clause_id
in
ctx_add_local_trait_clause span basename c.clause_id ctx)
ctx_add_local_trait_clause span origin basename c.clause_id ctx)
ctx clauses

(** Returns the lists of names for:
Expand All @@ -2056,14 +2091,14 @@ let ctx_add_local_trait_clauses (span : Meta.span)
for additional information.
*)
let ctx_add_generic_params (span : Meta.span) (current_def_name : Types.name)
(llbc_generics : Types.generic_params) (generics : generic_params)
(ctx : extraction_ctx) :
(origin : generic_origin) (llbc_generics : Types.generic_params)
(generics : generic_params) (ctx : extraction_ctx) :
extraction_ctx * string list * string list * string list =
let { types; const_generics; trait_clauses } = generics in
let ctx, tys = ctx_add_type_params span types ctx in
let ctx, cgs = ctx_add_const_generic_params span const_generics ctx in
let ctx, tys = ctx_add_type_params span origin types ctx in
let ctx, cgs = ctx_add_const_generic_params span origin const_generics ctx in
let ctx, tcs =
ctx_add_local_trait_clauses span current_def_name llbc_generics
ctx_add_local_trait_clauses span current_def_name origin llbc_generics
trait_clauses ctx
in
(ctx, tys, cgs, tcs)
Expand Down
46 changes: 25 additions & 21 deletions src/extract/ExtractTypes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx)
F.pp_print_string fmt s
| CgValue v -> extract_literal span fmt false inside v
| CgVar var ->
let id = TypesUtils.expect_free_var (Some span) var in
let s = ctx_get_const_generic_var span id ctx in
let origin, id = origin_from_de_bruijn_var var in
let s = ctx_get_const_generic_var span origin id ctx in
F.pp_print_string fmt s

let extract_literal_type (_ctx : extraction_ctx) (fmt : F.formatter)
Expand Down Expand Up @@ -582,7 +582,9 @@ let rec extract_ty (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter)
Collections.List.iter_link (F.pp_print_space fmt)
(extract_trait_ref span ctx fmt no_params_tys true)
trait_refs)))
| TVar vid -> F.pp_print_string fmt (ctx_get_type_var span vid ctx)
| TVar var ->
let origin, id = origin_from_de_bruijn_var var in
F.pp_print_string fmt (ctx_get_type_var span origin id ctx)
| TLiteral lty -> extract_literal_type ctx fmt lty
| TArrow (arg_ty, ret_ty) ->
if inside then F.pp_print_string fmt "(";
Expand Down Expand Up @@ -772,8 +774,9 @@ and extract_trait_instance_id (span : Meta.span) (ctx : extraction_ctx)
F.pp_print_string fmt name;
extract_generic_args span ctx fmt no_params_tys ~explicit generics;
if use_brackets then F.pp_print_string fmt ")"
| Clause id ->
let name = ctx_get_local_trait_clause span id ctx in
| Clause var ->
let origin, id = origin_from_de_bruijn_var var in
let name = ctx_get_local_trait_clause span origin id ctx in
F.pp_print_string fmt name
| ParentClause (inst_id, decl_id, clause_id) ->
(* Use the trait decl id to lookup the name *)
Expand Down Expand Up @@ -1295,9 +1298,10 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false)
?(use_forall_use_sep = true) ?(use_arrows = false)
?(as_implicits : bool = false) ?(space : bool ref option = None)
?(trait_decl : trait_decl option = None) (generics : generic_params)
(explicit : explicit_info option) (type_params : string list)
(cg_params : string list) (trait_clauses : string list) : unit =
?(trait_decl : trait_decl option = None) (origin : generic_origin)
(generics : generic_params) (explicit : explicit_info option)
(type_params : string list) (cg_params : string list)
(trait_clauses : string list) : unit =
let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* HOL4 doesn't support const generics *)
cassert __FILE__ __LINE__
Expand Down Expand Up @@ -1359,7 +1363,7 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
insert_req_space ();
(* ( *)
left_bracket expl;
let n = ctx_get_const_generic_var span var.index ctx in
let n = ctx_get_const_generic_var span origin var.index ctx in
print_implicit_symbol expl;
F.pp_print_string fmt n;
F.pp_print_space fmt ();
Expand All @@ -1378,7 +1382,7 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
insert_req_space ();
(* ( *)
left_bracket expl;
let n = ctx_get_local_trait_clause span clause.clause_id ctx in
let n = ctx_get_local_trait_clause span origin clause.clause_id ctx in
print_implicit_symbol expl;
F.pp_print_string fmt n;
F.pp_print_space fmt ();
Expand Down Expand Up @@ -1445,12 +1449,12 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
map snd dtype_params;
map
(fun ((_, cg) : _ * const_generic_var) ->
ctx_get_const_generic_var trait_decl.item_meta.span cg.index
ctx)
ctx_get_const_generic_var trait_decl.item_meta.span origin
cg.index ctx)
dcgs;
map
(fun (_, c) ->
ctx_get_local_trait_clause trait_decl.item_meta.span
ctx_get_local_trait_clause trait_decl.item_meta.span origin
c.clause_id ctx)
dtrait_clauses;
]
Expand Down Expand Up @@ -1507,7 +1511,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Add the type and const generic params - note that we need those bindings only for the
* body translation (they are not top-level) *)
let ctx_body, type_params, cg_params, trait_clauses =
ctx_add_generic_params def.item_meta.span def.item_meta.name
ctx_add_generic_params def.item_meta.span def.item_meta.name Item
def.llbc_generics def.generics ctx
in
(* Add a break before *)
Expand Down Expand Up @@ -1553,7 +1557,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
"Constant generics and type definitions with trait clauses are not \
supported yet when generating code for HOL4";
(* Print the generic parameters *)
extract_generic_params def.item_meta.span ctx_body fmt type_decl_group
extract_generic_params def.item_meta.span ctx_body fmt type_decl_group Item
~use_forall def.generics (Some def.explicit_info) type_params cg_params
trait_clauses;
(* Print the "=" if we extract the body*)
Expand Down Expand Up @@ -1795,7 +1799,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
if is_rec then
(* Add the type params *)
let ctx, type_params, cg_params, trait_clauses =
ctx_add_generic_params decl.item_meta.span decl.item_meta.name
ctx_add_generic_params decl.item_meta.span decl.item_meta.name Item
decl.llbc_generics decl.generics ctx
in
(* Record_var will be the ADT argument to the projector *)
Expand Down Expand Up @@ -1855,8 +1859,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
(* Print the generics *)
let as_implicits = true in
extract_generic_params decl.item_meta.span ctx fmt
TypeDeclId.Set.empty ~as_implicits decl.generics None type_params
cg_params trait_clauses;
TypeDeclId.Set.empty Item ~as_implicits decl.generics None
type_params cg_params trait_clauses;

(* Print the record parameter as "(x : ADT)" *)
F.pp_print_space fmt ();
Expand Down Expand Up @@ -1996,8 +2000,8 @@ let extract_type_decl_record_field_projectors_simp_lemmas (ctx : extraction_ctx)
if is_rec then
(* Add the type params *)
let ctx, type_params, cg_params, trait_clauses =
ctx_add_generic_params span decl.item_meta.name decl.llbc_generics
decl.generics ctx
ctx_add_generic_params span decl.item_meta.name Item
decl.llbc_generics decl.generics ctx
in
(* Name of the ADT *)
let def_name = ctx_get_local_type span decl.def_id ctx in
Expand Down Expand Up @@ -2042,7 +2046,7 @@ let extract_type_decl_record_field_projectors_simp_lemmas (ctx : extraction_ctx)
(* Print the generics *)
let as_implicits = true in
extract_generic_params span ctx fmt TypeDeclId.Set.empty ~as_implicits
decl.generics None type_params cg_params trait_clauses;
Item decl.generics None type_params cg_params trait_clauses;

(* Print the input parameters (the fields) *)
let print_field (ctx : extraction_ctx) (field_id : FieldId.id)
Expand Down
Loading