From 8fea7bb037a732a210364d2a0897c2242e1b3307 Mon Sep 17 00:00:00 2001 From: MystPi <86574651+MystPi@users.noreply.github.com> Date: Sun, 3 Mar 2024 14:36:46 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Part=206?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/type_inference.gleam | 342 +++++++++++++++++++++++++++++++++------ 1 file changed, 289 insertions(+), 53 deletions(-) diff --git a/src/type_inference.gleam b/src/type_inference.gleam index d9d9a9e..13f3505 100644 --- a/src/type_inference.gleam +++ b/src/type_inference.gleam @@ -1,8 +1,11 @@ import gleam/io +import gleam/int import gleam/list import gleam/bool +import gleam/pair import gleam/dict.{type Dict} import gleam/result.{try} +import gleam/option.{type Option, None, Some} // ---- TYPES ------------------------------------------------------------------ @@ -20,13 +23,40 @@ type Type { TVariable(index: Int) } -/// The AST for the hypothetical language +/// The AST for the hypothetical language. Type annotations are optional, and any +/// missing annotations can be filled in by the `infer` function. type Expression { - ELambda(arg: String, body: Expression) - EApply(expression: Expression, arg: Expression) + // A lambda expression, capable of multiple parameters + ELambda( + parameters: List(Parameter), + return_type: Option(Type), + body: Expression, + ) + // A function application + EApply(function: Expression, arguments: List(Expression)) + // A let expression with a body in which the new binding can be used + ELet( + name: String, + type_annotation: Option(Type), + value: Expression, + body: Expression, + ) + // A variable usage EVariable(name: String) + // Value literals + EInt(value: Int) + EString(value: String) + EArray(item_type: Option(Type), items: List(Expression)) +} + +/// A parameter has a name and an optional type annotation. +type Parameter { + Parameter(name: String, type_annotation: Option(Type)) } +type Environment = + Dict(String, Type) + /// A constraint is something we know about types and their relationships type Constraint { /// This constraint means the two types must be equal @@ -45,85 +75,214 @@ type TypeError { // ---- TESTING ---------------------------------------------------------------- pub fn main() { - ELambda("x", EApply(EApply(EVariable("+"), EVariable("x")), EVariable("x"))) - |> infer + ELet( + name: "add", + type_annotation: None, + value: ELambda( + parameters: [Parameter("a", None), Parameter("b", None)], + return_type: None, + body: EApply(EVariable("+"), [EVariable("a"), EVariable("b")]), + ), + body: EApply(EVariable("add"), [EInt(1), EInt(2)]), + ) + |> infer_expression(initial_environment()) |> io.debug Nil } +/// Takes an expression and returns the same expression with type annotations +/// inferred and filled in, or an error if type checking failed. +fn infer_expression( + expression: Expression, + environment: Environment, +) -> Result(Expression, TypeError) { + let initial_ctx = Ctx(substitution: dict.new(), type_constraints: []) + let #(fresh_var, ctx) = fresh_type_variable(initial_ctx) + + // 1. Infer the type of the expression without expecting it to be a certain + // type by passing a fresh type variable. + use #(inferred, ctx) <- try(infer(expression, fresh_var, environment, ctx)) + // 2. Solve the generated constraints + use ctx <- try(solve_constraints(ctx)) + // 3. Replace type variables in the inferred expression with their substitutions + let result = substitute_expression(inferred, ctx) + Ok(result) +} + fn initial_environment() -> Dict(String, Type) { + // Operators in this case are functions. It wouldn't be very difficult to add + // operators to the Expression type itself, but since the article doesn't I'll + // hold off for now to keep things simple. ["+", "-", "*", "/"] |> list.map(fn(op) { #( op, - TConstructor("Function1", [ + TConstructor("Function2", [ + TConstructor("Int", []), + TConstructor("Int", []), TConstructor("Int", []), - TConstructor("Function1", [ - TConstructor("Int", []), - TConstructor("Int", []), - ]), ]), ) }) |> dict.from_list } -fn initial_ctx() -> Ctx { - Ctx(dict.new(), []) -} - -fn infer(expression: Expression) -> Result(Type, TypeError) { - // 1. Infer the type of the expression - use #(inferred, ctx) <- try(infer_type( - expression, - initial_environment(), - initial_ctx(), - )) - // 2. Solve the constraints - use ctx <- try(solve_constraints(ctx)) - // 3. Replace type variables in the inferred type with their substitutions - let result = substitute(inferred, ctx) - Ok(result) -} - // ---- TYPE INFERENCE --------------------------------------------------------- -/// Turn an expression into type variables, adding constraints to things we know -/// must be certain types, and inserting variables and functions into the environment -/// when appropriate. +/// Infer the type of an expression, filling in any missing type annotations with +/// fresh type variables, and adding constraints to things we know should be certain +/// types. The expected type of the expression must be passed to the function, +/// but a fresh type variable can be passed instead to infer without checking. /// -fn infer_type( +fn infer( expression: Expression, - environment: Dict(String, Type), + expected_type: Type, + environment: Environment, ctx: Ctx, -) -> Result(#(Type, Ctx), TypeError) { +) -> Result(#(Expression, Ctx), TypeError) { case expression { - ELambda(name, body) -> { - let #(arg_type, ctx) = fresh_type_variable(ctx) - let environment2 = dict.insert(environment, name, arg_type) - use #(return_type, ctx) <- try(infer_type(body, environment2, ctx)) - Ok(#(TConstructor("Function1", [arg_type, return_type]), ctx)) + ELambda(parameters, return_type, body) -> { + let #(new_return_type, ctx) = type_or_fresh_variable(return_type, ctx) + let #(new_parameter_types, ctx) = + map_fold(parameters, ctx, fn(ctx, p) { + type_or_fresh_variable(p.type_annotation, ctx) + }) + // Update the parameters to include the new types + let new_parameters = + list.map2(parameters, new_parameter_types, fn(p, t) { + Parameter(..p, type_annotation: Some(t)) + }) + // Add the parameters to the environment + let new_environment = + list.fold(new_parameters, environment, fn(env, p) { + let assert Parameter(name, Some(t)) = p + dict.insert(env, name, t) + }) + // Infer the type of the body, expecting it to be the return type (which + // could be a fresh type variable if an annotation wasn't given) + use #(new_body, ctx) <- try(infer( + body, + new_return_type, + new_environment, + ctx, + )) + + let constructor_name = + "Function" <> int.to_string(list.length(parameters)) + + let ctx = + push_constraint( + ctx, + CEquality( + expected_type, + TConstructor( + constructor_name, + list.reverse([new_return_type, ..new_parameter_types]), + ), + ), + ) + + Ok(#(ELambda(new_parameters, Some(new_return_type), new_body), ctx)) } - EVariable(name) -> + EApply(function, arguments) -> { + // Create a fresh type variable for each of the arguments + let #(argument_types, ctx) = + map_fold(arguments, ctx, fn(ctx, _) { fresh_type_variable(ctx) }) + // Create a function type based on the arguments and expected return type + let constructor_name = "Function" <> int.to_string(list.length(arguments)) + let function_type = + TConstructor( + constructor_name, + list.reverse([expected_type, ..argument_types]), + ) + // Infer the type of function, expecting it to be `function_type`. + use #(new_function, ctx) <- try(infer( + function, + function_type, + environment, + ctx, + )) + // Infer the missing types of the arguments + use #(new_arguments, ctx) <- try( + list.zip(arguments, argument_types) + |> try_map_fold(ctx, fn(ctx, pair) { + infer(pair.0, pair.1, environment, ctx) + }), + ) + + Ok(#(EApply(new_function, new_arguments), ctx)) + } + + EVariable(name) -> { case dict.get(environment, name) { - Ok(t) -> Ok(#(t, ctx)) - Error(_) -> Error(TypeError("Variable not defined: " <> name)) + Ok(t) -> { + // Constrain the variable type to be the expected type + let ctx = push_constraint(ctx, CEquality(expected_type, t)) + // Simply return the original expression. There are no type annotations + // to fill in! + Ok(#(expression, ctx)) + } + Error(_) -> Error(TypeError("Variable not in scope: " <> name)) } + } + + ELet(name, type_annotation, value, body) -> { + let #(new_type_annotation, ctx) = + type_or_fresh_variable(type_annotation, ctx) + use #(new_value, ctx) <- try(infer( + value, + new_type_annotation, + environment, + ctx, + )) + // Add the variable to the environment so it can be referenced in the body + let new_environment = dict.insert(environment, name, new_type_annotation) + use #(new_body, ctx) <- try(infer( + body, + expected_type, + new_environment, + ctx, + )) + + Ok(#(ELet(name, Some(new_type_annotation), new_value, new_body), ctx)) + } + + // These literals are easy—we just have to constrain the expected type. There + // aren't any annotations to fill in. + EInt(_) -> { + let ctx = + push_constraint(ctx, CEquality(expected_type, TConstructor("Int", []))) + + Ok(#(expression, ctx)) + } + + EString(_) -> { + let ctx = + push_constraint( + ctx, + CEquality(expected_type, TConstructor("String", [])), + ) - EApply(function, arg) -> { - use #(function_type, ctx) <- try(infer_type(function, environment, ctx)) - use #(arg_type, ctx) <- try(infer_type(arg, environment, ctx)) - let #(return_type, ctx) = fresh_type_variable(ctx) + Ok(#(expression, ctx)) + } - let constraint = - CEquality( - function_type, - TConstructor("Function1", [arg_type, return_type]), + // Array literals are slightly more complex. + EArray(item_type, items) -> { + let #(new_item_type, ctx) = type_or_fresh_variable(item_type, ctx) + // Infer the types of the items, expecting them to be `new_item_type`. + use #(new_items, ctx) <- try( + try_map_fold(items, ctx, fn(ctx, item) { + infer(item, new_item_type, environment, ctx) + }), + ) + let ctx = + push_constraint( + ctx, + CEquality(expected_type, TConstructor("Array", [new_item_type])), ) - let ctx = push_constraint(ctx, constraint) - Ok(#(return_type, ctx)) + Ok(#(EArray(Some(new_item_type), new_items), ctx)) } } } @@ -136,6 +295,15 @@ fn fresh_type_variable(ctx: Ctx) -> #(Type, Ctx) { #(result, insert_substitution(ctx, index, result)) } +/// Return a fresh type variable if the optional type isn't provided. +/// +fn type_or_fresh_variable(annotation: Option(Type), ctx: Ctx) -> #(Type, Ctx) { + case annotation { + Some(a) -> #(a, ctx) + None -> fresh_type_variable(ctx) + } +} + /// Add a constraint to the type_constraints list. /// fn push_constraint(ctx: Ctx, constraint: Constraint) -> Ctx { @@ -145,7 +313,7 @@ fn push_constraint(ctx: Ctx, constraint: Constraint) -> Ctx { // TODO: this function could be easily modified to return multiple type errors // instead of just the first one. // -/// "Solve" constraints by going through them and making sure they are true, then +/// "Solve" constraints by going through them and making sure they hold true, then /// clearing the constraints list. /// fn solve_constraints(ctx: Ctx) -> Result(Ctx, TypeError) { @@ -262,3 +430,71 @@ fn substitute(t: Type, ctx: Ctx) -> Type { TConstructor(name, list.map(generics, substitute(_, ctx))) } } + +/// Recursively replace solved type variables in an expression with their +/// substitutions. +/// +fn substitute_expression(expression: Expression, ctx: Ctx) -> Expression { + case expression { + ELambda(parameters, return_type, body) -> { + let new_return_type = option.map(return_type, substitute(_, ctx)) + let new_parameters = + list.map(parameters, fn(p) { + let assert Parameter(name, Some(t)) = p + Parameter(name, Some(substitute(t, ctx))) + }) + let new_body = substitute_expression(body, ctx) + ELambda(new_parameters, new_return_type, new_body) + } + + EApply(function, arguments) -> { + let new_function = substitute_expression(function, ctx) + let new_arguments = list.map(arguments, substitute_expression(_, ctx)) + EApply(new_function, new_arguments) + } + + EVariable(_) | EInt(_) | EString(_) -> expression + + ELet(name, type_annotation, value, body) -> { + let new_type_annotation = option.map(type_annotation, substitute(_, ctx)) + let new_value = substitute_expression(value, ctx) + let new_body = substitute_expression(body, ctx) + ELet(name, new_type_annotation, new_value, new_body) + } + + EArray(item_type, items) -> { + let new_item_type = option.map(item_type, substitute(_, ctx)) + let new_items = list.map(items, substitute_expression(_, ctx)) + EArray(new_item_type, new_items) + } + } +} + +// ---- UTILS ------------------------------------------------------------------ + +/// Same as `list.map_fold` but with the return tuples swapped. This function is +/// useful because other functions (such as `fresh_type_variable`) return +/// `#(a, Ctx)` instead of `#(Ctx, a)` and `Ctx` is being folded. +/// +fn map_fold(list: List(a), acc: b, fun: fn(b, a) -> #(c, b)) -> #(List(c), b) { + list.map_fold(list, acc, fn(b, a) { + fun(b, a) + |> pair.swap + }) + |> pair.swap +} + +/// A combination of `list.try_fold` and `map_fold`. +/// +fn try_map_fold( + list: List(a), + folding: b, + fun: fn(b, a) -> Result(#(c, b), d), +) -> Result(#(List(c), b), d) { + list.try_fold(list, #([], folding), fn(state, curr) { + let #(acc, folding) = state + use #(result, folding) <- try(fun(folding, curr)) + Ok(#([result, ..acc], folding)) + }) + |> result.map(pair.map_first(_, list.reverse)) +}