Skip to content

Commit

Permalink
Merge pull request #412 from Nadrieril/bind-methods-in-pure
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadrieril authored Jan 8, 2025
2 parents 4c81315 + 57e185e commit dd87c9d
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 309 deletions.
256 changes: 73 additions & 183 deletions src/extract/Extract.ml

Large diffs are not rendered by default.

125 changes: 80 additions & 45 deletions src/extract/ExtractBase.ml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,18 @@ let decl_is_not_last_from_group (kind : decl_kind) : bool =

type type_decl_kind = Enum | Struct | Tuple [@@deriving show]

(** Generics can be bound in two places: each item has its generics, and
additionally within a trait decl or impl each method has its own generics.
We distinguish these two cases here. In charon, the distinction is made
thanks to `de_bruijn_var`.
Note that for the generics of a top-level `fun_decl` we always use `Item`;
`Method` only refers to the inner binder found in the list of methods in a
trait_decl/trait_impl.
*)
type generic_origin = Item | Method

(** We use identifiers to look for name clashes *)
type id =
and id =
| GlobalId of A.GlobalDeclId.id
| FunId of fun_id
| TerminationMeasureId of (A.fun_id * LoopId.id option)
Expand Down Expand Up @@ -162,12 +172,12 @@ type id =
must be unique (it is the case in F* ) which is why we register
them here.
*)
| TypeVarId of TypeVarId.id
| ConstGenericVarId of ConstGenericVarId.id
| VarId of VarId.id
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
| TypeVarId of generic_origin * TypeVarId.id
| ConstGenericVarId of generic_origin * ConstGenericVarId.id
| LocalTraitClauseId of generic_origin * TraitClauseId.id
| TraitDeclConstructorId of TraitDeclId.id
| TraitMethodId of TraitDeclId.id * string
| TraitItemId of TraitDeclId.id * string
Expand Down Expand Up @@ -674,14 +684,19 @@ let id_to_string (span : Meta.span option) (id : id) (ctx : extraction_ctx) :
let field_name = adt_field_to_string span ctx id field_id in
"type name: " ^ type_name ^ ", field name: " ^ field_name
| UnknownId -> "keyword"
| TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id
| ConstGenericVarId id ->
"const_generic_var_id: " ^ ConstGenericVarId.to_string id
| VarId id -> "var_id: " ^ VarId.to_string id
| TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
| TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
| LocalTraitClauseId id ->
"local_trait_clause_id: " ^ TraitClauseId.to_string id
| TypeVarId (origin, id) ->
"type_var_id: " ^ TypeVarId.to_string id ^ " from "
^ show_generic_origin origin
| ConstGenericVarId (origin, id) ->
"const_generic_var_id: "
^ ConstGenericVarId.to_string id
^ " from " ^ show_generic_origin origin
| LocalTraitClauseId (origin, id) ->
"local_trait_clause_id: " ^ TraitClauseId.to_string id ^ " from "
^ show_generic_origin origin
| TraitDeclConstructorId id ->
"trait_decl_constructor: " ^ trait_decl_id_to_string id
| TraitParentClauseId (id, clause_id) ->
Expand Down Expand Up @@ -771,17 +786,32 @@ let ctx_get_var (span : Meta.span) (id : VarId.id) (ctx : extraction_ctx) :
string =
ctx_get (Some span) (VarId id) ctx

let ctx_get_type_var (span : Meta.span) (id : TypeVarId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (TypeVarId id) ctx
(** This warrants explanations. Charon supports several levels of nested
binders; however there are currently only two cases where we bind
non-lifetime variables: at the top-level of each item, and for each method
inside a trait_decl/trait_impl. Moreover, we use `Free` vars to identify
item-bound vars. This means that we can identify which binder a variable
comes from without rigorously tracking binder levels, which is what this
function does.
Note that the `de_bruijn_id`s are wrong anyway because we kept charon's
binding levels but forgot all the region binders.
*)
let origin_from_de_bruijn_var (var : 'a de_bruijn_var) : generic_origin * 'a =
match var with
| Bound (_, id) -> (Method, id)
| Free id -> (Item, id)

let ctx_get_const_generic_var (span : Meta.span) (id : ConstGenericVarId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (ConstGenericVarId id) ctx
let ctx_get_type_var (span : Meta.span) (origin : generic_origin)
(id : TypeVarId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (TypeVarId (origin, id)) ctx

let ctx_get_local_trait_clause (span : Meta.span) (id : TraitClauseId.id)
(ctx : extraction_ctx) : string =
ctx_get (Some span) (LocalTraitClauseId id) ctx
let ctx_get_const_generic_var (span : Meta.span) (origin : generic_origin)
(id : ConstGenericVarId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (ConstGenericVarId (origin, id)) ctx

let ctx_get_local_trait_clause (span : Meta.span) (origin : generic_origin)
(id : TraitClauseId.id) (ctx : extraction_ctx) : string =
ctx_get (Some span) (LocalTraitClauseId (origin, id)) ctx

let ctx_get_field (span : Meta.span) (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
Expand Down Expand Up @@ -1956,27 +1986,29 @@ let basename_to_unique (ctx : extraction_ctx) (name : string) =
basename_to_unique_aux collision name_append_index name

(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (span : Meta.span) (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
let ctx_add_type_var (span : Meta.span) (origin : generic_origin)
(basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = ctx_compute_type_var_basename ctx basename in
let name = basename_to_unique ctx name in
let ctx = ctx_add span (TypeVarId id) name ctx in
let ctx = ctx_add span (TypeVarId (origin, id)) name ctx in
(ctx, name)

(** Generate a unique const generic variable name and add it to the context *)
let ctx_add_const_generic_var (span : Meta.span) (basename : string)
(id : ConstGenericVarId.id) (ctx : extraction_ctx) : extraction_ctx * string
=
let ctx_add_const_generic_var (span : Meta.span) (origin : generic_origin)
(basename : string) (id : ConstGenericVarId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = ctx_compute_const_generic_var_basename ctx basename in
let name = basename_to_unique ctx name in
let ctx = ctx_add span (ConstGenericVarId id) name ctx in
let ctx = ctx_add span (ConstGenericVarId (origin, id)) name ctx in
(ctx, name)

(** See {!ctx_add_type_var} *)
let ctx_add_type_vars (span : Meta.span) (vars : (string * TypeVarId.id) list)
(ctx : extraction_ctx) : extraction_ctx * string list =
let ctx_add_type_vars (span : Meta.span) (origin : generic_origin)
(vars : (string * TypeVarId.id) list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
(fun ctx (name, id) -> ctx_add_type_var span name id ctx)
(fun ctx (name, id) -> ctx_add_type_var span origin name id ctx)
ctx vars

(** Generate a unique variable name and add it to the context *)
Expand All @@ -1995,10 +2027,11 @@ let ctx_add_trait_self_clause (span : Meta.span) (ctx : extraction_ctx) :
(ctx, name)

(** Generate a unique trait clause name and add it to the context *)
let ctx_add_local_trait_clause (span : Meta.span) (basename : string)
(id : TraitClauseId.id) (ctx : extraction_ctx) : extraction_ctx * string =
let ctx_add_local_trait_clause (span : Meta.span) (origin : generic_origin)
(basename : string) (id : TraitClauseId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
let name = basename_to_unique ctx basename in
let ctx = ctx_add span (LocalTraitClauseId id) name ctx in
let ctx = ctx_add span (LocalTraitClauseId (origin, id)) name ctx in
(ctx, name)

(** See {!ctx_add_var} *)
Expand All @@ -2010,18 +2043,20 @@ let ctx_add_vars (span : Meta.span) (vars : var list) (ctx : extraction_ctx) :
ctx_add_var span name v.id ctx)
ctx vars

let ctx_add_type_params (span : Meta.span) (vars : type_var list)
(ctx : extraction_ctx) : extraction_ctx * string list =
let ctx_add_type_params (span : Meta.span) (origin : generic_origin)
(vars : type_var list) (ctx : extraction_ctx) : extraction_ctx * string list
=
List.fold_left_map
(fun ctx (var : type_var) -> ctx_add_type_var span var.name var.index ctx)
(fun ctx (var : type_var) ->
ctx_add_type_var span origin var.name var.index ctx)
ctx vars

let ctx_add_const_generic_params (span : Meta.span)
let ctx_add_const_generic_params (span : Meta.span) (origin : generic_origin)
(vars : const_generic_var list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
(fun ctx (var : const_generic_var) ->
ctx_add_const_generic_var span var.name var.index ctx)
ctx_add_const_generic_var span origin var.name var.index ctx)
ctx vars

(** Returns the lists of names for:
Expand All @@ -2034,16 +2069,16 @@ let ctx_add_const_generic_params (span : Meta.span)
for additional information.
*)
let ctx_add_local_trait_clauses (span : Meta.span)
(current_def_name : Types.name) (llbc_generics : Types.generic_params)
(clauses : trait_clause list) (ctx : extraction_ctx) :
extraction_ctx * string list =
(current_def_name : Types.name) (origin : generic_origin)
(llbc_generics : Types.generic_params) (clauses : trait_clause list)
(ctx : extraction_ctx) : extraction_ctx * string list =
List.fold_left_map
(fun ctx (c : trait_clause) ->
let basename =
ctx_compute_trait_clause_basename ctx current_def_name llbc_generics
c.clause_id
in
ctx_add_local_trait_clause span basename c.clause_id ctx)
ctx_add_local_trait_clause span origin basename c.clause_id ctx)
ctx clauses

(** Returns the lists of names for:
Expand All @@ -2056,14 +2091,14 @@ let ctx_add_local_trait_clauses (span : Meta.span)
for additional information.
*)
let ctx_add_generic_params (span : Meta.span) (current_def_name : Types.name)
(llbc_generics : Types.generic_params) (generics : generic_params)
(ctx : extraction_ctx) :
(origin : generic_origin) (llbc_generics : Types.generic_params)
(generics : generic_params) (ctx : extraction_ctx) :
extraction_ctx * string list * string list * string list =
let { types; const_generics; trait_clauses } = generics in
let ctx, tys = ctx_add_type_params span types ctx in
let ctx, cgs = ctx_add_const_generic_params span const_generics ctx in
let ctx, tys = ctx_add_type_params span origin types ctx in
let ctx, cgs = ctx_add_const_generic_params span origin const_generics ctx in
let ctx, tcs =
ctx_add_local_trait_clauses span current_def_name llbc_generics
ctx_add_local_trait_clauses span current_def_name origin llbc_generics
trait_clauses ctx
in
(ctx, tys, cgs, tcs)
Expand Down
46 changes: 25 additions & 21 deletions src/extract/ExtractTypes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx)
F.pp_print_string fmt s
| CgValue v -> extract_literal span fmt false inside v
| CgVar var ->
let id = TypesUtils.expect_free_var (Some span) var in
let s = ctx_get_const_generic_var span id ctx in
let origin, id = origin_from_de_bruijn_var var in
let s = ctx_get_const_generic_var span origin id ctx in
F.pp_print_string fmt s

let extract_literal_type (_ctx : extraction_ctx) (fmt : F.formatter)
Expand Down Expand Up @@ -582,7 +582,9 @@ let rec extract_ty (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter)
Collections.List.iter_link (F.pp_print_space fmt)
(extract_trait_ref span ctx fmt no_params_tys true)
trait_refs)))
| TVar vid -> F.pp_print_string fmt (ctx_get_type_var span vid ctx)
| TVar var ->
let origin, id = origin_from_de_bruijn_var var in
F.pp_print_string fmt (ctx_get_type_var span origin id ctx)
| TLiteral lty -> extract_literal_type ctx fmt lty
| TArrow (arg_ty, ret_ty) ->
if inside then F.pp_print_string fmt "(";
Expand Down Expand Up @@ -772,8 +774,9 @@ and extract_trait_instance_id (span : Meta.span) (ctx : extraction_ctx)
F.pp_print_string fmt name;
extract_generic_args span ctx fmt no_params_tys ~explicit generics;
if use_brackets then F.pp_print_string fmt ")"
| Clause id ->
let name = ctx_get_local_trait_clause span id ctx in
| Clause var ->
let origin, id = origin_from_de_bruijn_var var in
let name = ctx_get_local_trait_clause span origin id ctx in
F.pp_print_string fmt name
| ParentClause (inst_id, decl_id, clause_id) ->
(* Use the trait decl id to lookup the name *)
Expand Down Expand Up @@ -1295,9 +1298,10 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false)
?(use_forall_use_sep = true) ?(use_arrows = false)
?(as_implicits : bool = false) ?(space : bool ref option = None)
?(trait_decl : trait_decl option = None) (generics : generic_params)
(explicit : explicit_info option) (type_params : string list)
(cg_params : string list) (trait_clauses : string list) : unit =
?(trait_decl : trait_decl option = None) (origin : generic_origin)
(generics : generic_params) (explicit : explicit_info option)
(type_params : string list) (cg_params : string list)
(trait_clauses : string list) : unit =
let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* HOL4 doesn't support const generics *)
cassert __FILE__ __LINE__
Expand Down Expand Up @@ -1359,7 +1363,7 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
insert_req_space ();
(* ( *)
left_bracket expl;
let n = ctx_get_const_generic_var span var.index ctx in
let n = ctx_get_const_generic_var span origin var.index ctx in
print_implicit_symbol expl;
F.pp_print_string fmt n;
F.pp_print_space fmt ();
Expand All @@ -1378,7 +1382,7 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
insert_req_space ();
(* ( *)
left_bracket expl;
let n = ctx_get_local_trait_clause span clause.clause_id ctx in
let n = ctx_get_local_trait_clause span origin clause.clause_id ctx in
print_implicit_symbol expl;
F.pp_print_string fmt n;
F.pp_print_space fmt ();
Expand Down Expand Up @@ -1445,12 +1449,12 @@ let extract_generic_params (span : Meta.span) (ctx : extraction_ctx)
map snd dtype_params;
map
(fun ((_, cg) : _ * const_generic_var) ->
ctx_get_const_generic_var trait_decl.item_meta.span cg.index
ctx)
ctx_get_const_generic_var trait_decl.item_meta.span origin
cg.index ctx)
dcgs;
map
(fun (_, c) ->
ctx_get_local_trait_clause trait_decl.item_meta.span
ctx_get_local_trait_clause trait_decl.item_meta.span origin
c.clause_id ctx)
dtrait_clauses;
]
Expand Down Expand Up @@ -1507,7 +1511,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Add the type and const generic params - note that we need those bindings only for the
* body translation (they are not top-level) *)
let ctx_body, type_params, cg_params, trait_clauses =
ctx_add_generic_params def.item_meta.span def.item_meta.name
ctx_add_generic_params def.item_meta.span def.item_meta.name Item
def.llbc_generics def.generics ctx
in
(* Add a break before *)
Expand Down Expand Up @@ -1553,7 +1557,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
"Constant generics and type definitions with trait clauses are not \
supported yet when generating code for HOL4";
(* Print the generic parameters *)
extract_generic_params def.item_meta.span ctx_body fmt type_decl_group
extract_generic_params def.item_meta.span ctx_body fmt type_decl_group Item
~use_forall def.generics (Some def.explicit_info) type_params cg_params
trait_clauses;
(* Print the "=" if we extract the body*)
Expand Down Expand Up @@ -1795,7 +1799,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
if is_rec then
(* Add the type params *)
let ctx, type_params, cg_params, trait_clauses =
ctx_add_generic_params decl.item_meta.span decl.item_meta.name
ctx_add_generic_params decl.item_meta.span decl.item_meta.name Item
decl.llbc_generics decl.generics ctx
in
(* Record_var will be the ADT argument to the projector *)
Expand Down Expand Up @@ -1855,8 +1859,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
(* Print the generics *)
let as_implicits = true in
extract_generic_params decl.item_meta.span ctx fmt
TypeDeclId.Set.empty ~as_implicits decl.generics None type_params
cg_params trait_clauses;
TypeDeclId.Set.empty Item ~as_implicits decl.generics None
type_params cg_params trait_clauses;

(* Print the record parameter as "(x : ADT)" *)
F.pp_print_space fmt ();
Expand Down Expand Up @@ -1996,8 +2000,8 @@ let extract_type_decl_record_field_projectors_simp_lemmas (ctx : extraction_ctx)
if is_rec then
(* Add the type params *)
let ctx, type_params, cg_params, trait_clauses =
ctx_add_generic_params span decl.item_meta.name decl.llbc_generics
decl.generics ctx
ctx_add_generic_params span decl.item_meta.name Item
decl.llbc_generics decl.generics ctx
in
(* Name of the ADT *)
let def_name = ctx_get_local_type span decl.def_id ctx in
Expand Down Expand Up @@ -2042,7 +2046,7 @@ let extract_type_decl_record_field_projectors_simp_lemmas (ctx : extraction_ctx)
(* Print the generics *)
let as_implicits = true in
extract_generic_params span ctx fmt TypeDeclId.Set.empty ~as_implicits
decl.generics None type_params cg_params trait_clauses;
Item decl.generics None type_params cg_params trait_clauses;

(* Print the input parameters (the fields) *)
let print_field (ctx : extraction_ctx) (field_id : FieldId.id)
Expand Down
Loading

0 comments on commit dd87c9d

Please sign in to comment.