Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: reimplement interface type asserts #4375

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 90 additions & 27 deletions compiler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"go/token"
"go/types"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -180,6 +181,15 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
typeFieldTypes := []*types.Var{
types.NewVar(token.NoPos, nil, "kind", types.Typ[types.Int8]),
}
var methods []*types.Func
for i := 0; i < ms.Len(); i++ {
methods = append(methods, ms.At(i).Obj().(*types.Func))
}
methodSetType := types.NewStruct([]*types.Var{
types.NewVar(token.NoPos, nil, "length", types.Typ[types.Uintptr]),
types.NewVar(token.NoPos, nil, "methods", types.NewArray(types.Typ[types.UnsafePointer], int64(len(methods)))),
}, nil)
methodSetValue := c.getMethodSetValue(methods)
switch typ := typ.(type) {
case *types.Basic:
typeFieldTypes = append(typeFieldTypes,
Expand All @@ -196,6 +206,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
types.NewVar(token.NoPos, nil, "underlying", types.Typ[types.UnsafePointer]),
types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]),
types.NewVar(token.NoPos, nil, "methods", methodSetType),
types.NewVar(token.NoPos, nil, "name", types.NewArray(types.Typ[types.Int8], int64(len(pkgname)+1+len(name)+1))),
)
case *types.Chan:
Expand All @@ -214,6 +225,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
typeFieldTypes = append(typeFieldTypes,
types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]),
types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]),
types.NewVar(token.NoPos, nil, "methods", methodSetType),
)
case *types.Array:
typeFieldTypes = append(typeFieldTypes,
Expand All @@ -238,12 +250,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
types.NewVar(token.NoPos, nil, "size", types.Typ[types.Uint32]),
types.NewVar(token.NoPos, nil, "numFields", types.Typ[types.Uint16]),
types.NewVar(token.NoPos, nil, "fields", types.NewArray(c.getRuntimeType("structField"), int64(typ.NumFields()))),
types.NewVar(token.NoPos, nil, "methods", methodSetType),
)
case *types.Interface:
typeFieldTypes = append(typeFieldTypes,
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
types.NewVar(token.NoPos, nil, "methods", methodSetType),
)
// TODO: methods
case *types.Signature:
typeFieldTypes = append(typeFieldTypes,
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
Expand Down Expand Up @@ -294,6 +307,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
c.getTypeCode(types.NewPointer(typ)), // ptrTo
c.getTypeCode(typ.Underlying()), // underlying
pkgPathPtr, // pkgpath pointer
methodSetValue, // methods
c.ctx.ConstString(pkgname+"."+name+"\x00", false), // name
}
metabyte |= 1 << 5 // "named" flag
Expand Down Expand Up @@ -323,6 +337,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
typeFields = []llvm.Value{
llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods
c.getTypeCode(typ.Elem()),
methodSetValue, // methods
}
case *types.Array:
typeFields = []llvm.Value{
Expand Down Expand Up @@ -404,9 +419,12 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
}))
}
typeFields = append(typeFields, llvm.ConstArray(structFieldType, fields))
typeFields = append(typeFields, methodSetValue)
case *types.Interface:
typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))}
// TODO: methods
typeFields = []llvm.Value{
c.getTypeCode(types.NewPointer(typ)),
methodSetValue,
}
case *types.Signature:
typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))}
// TODO: params, return values, etc
Expand Down Expand Up @@ -685,25 +703,24 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
assertedType := b.getLLVMType(expr.AssertedType)

actualTypeNum := b.CreateExtractValue(itf, 0, "interface.type")
commaOk := llvm.Value{}
var commaOk llvm.Value

if intf, ok := expr.AssertedType.Underlying().(*types.Interface); ok {
if intf.Empty() {
// intf is the empty interface => no methods
// This type assertion always succeeds, so we can just set commaOk to true.
commaOk = llvm.ConstInt(b.ctx.Int1Type(), 1, true)
} else {
// Type assert on interface type with methods.
// This is a call to an interface type assert function.
// The interface lowering pass will define this function by filling it
// with a type switch over all concrete types that implement this
// interface, and returning whether it's one of the matched types.
// This is very different from how interface asserts are implemented in
// the main Go compiler, where the runtime checks whether the type
// implements each method of the interface. See:
// Type assert using interface type with methods.
// This is implemented using a runtime call, which checks that the
// type implements each method of the interface.
// For comparison, here is how the Go compiler does this (which is
// very similar):
// https://research.swtch.com/interfaces
fn := b.getInterfaceImplementsFunc(expr.AssertedType)
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
commaOk = b.createRuntimeCall("typeImplementsMethodSet", []llvm.Value{
actualTypeNum,
b.getInterfaceMethodSet(intf),
}, "")
}
} else {
name, _ := getTypeCodeName(expr.AssertedType)
Expand Down Expand Up @@ -780,20 +797,66 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
return strings.Join(methods, "; ")
}

// getInterfaceImplementsFunc returns a declared function that works as a type
// switch. The interface lowering pass will define this function.
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
s, _ := getTypeCodeName(assertedType.Underlying())
fnName := s + ".$typeassert"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.dataPtrType}, false)
llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType)
c.addStandardDeclaredAttributes(llvmFn)
methods := c.getMethodsString(assertedType.Underlying().(*types.Interface))
llvmFn.AddFunctionAttr(c.ctx.CreateStringAttribute("tinygo-methods", methods))
// Get the global that contains an interface method set, creating it if needed.
func (c *compilerContext) getInterfaceMethodSet(t *types.Interface) llvm.Value {
s, _ := getTypeCodeName(t)
methodSetName := s + "$itfmethods"
methodSet := c.mod.NamedFunction(methodSetName)
if methodSet.IsNil() {
var methods []*types.Func
for i := 0; i < t.NumMethods(); i++ {
methods = append(methods, t.Method(i))
}
if len(methods) == 0 {
// This *should* be unreachable: the caller checks whether the
// interface is empty before creating a method set.
panic("unreachable")
}

methodSetValue := c.getMethodSetValue(methods)
methodSet = llvm.AddGlobal(c.mod, methodSetValue.Type(), methodSetName)
methodSet.SetInitializer(methodSetValue)
methodSet.SetGlobalConstant(true)
methodSet.SetLinkage(llvm.LinkOnceODRLinkage)
methodSet.SetAlignment(c.targetData.ABITypeAlignment(methodSetValue.Type()))
methodSet.SetUnnamedAddr(true)
}
return llvmFn

return methodSet
}

// Get the method set value that is used in a number of type structs.
func (c *compilerContext) getMethodSetValue(methods []*types.Func) llvm.Value {
// Create a sorted list of method names.
var methodRefNames []string
for _, method := range methods {
name := method.Name()
if !token.IsExported(name) {
name = method.Pkg().Path() + "." + name
}
s, _ := getTypeCodeName(method.Type())
methodRefNames = append(methodRefNames, "reflect/types.signature:"+name+":"+s)
}
sort.Strings(methodRefNames)

// Turn this slice of strings in a slice of global variables.
var methodRefValues []llvm.Value
for _, name := range methodRefNames {
value := c.mod.NamedGlobal(name)
if value.IsNil() {
value = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), name)
value.SetInitializer(llvm.ConstNull(c.ctx.Int8Type()))
value.SetGlobalConstant(true)
value.SetLinkage(llvm.LinkOnceODRLinkage)
value.SetAlignment(1)
}
methodRefValues = append(methodRefValues, value)
}

return c.ctx.ConstStruct([]llvm.Value{
llvm.ConstInt(c.uintptrType, uint64(len(methodRefValues)), false),
llvm.ConstArray(c.dataPtrType, methodRefValues),
}, false)
}

// getInvokeFunction returns the thunk to call the given interface method. The
Expand Down
2 changes: 1 addition & 1 deletion compiler/testdata/gc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ target triple = "wasm32-unknown-wasi"
@"runtime/gc.layout:62-2000000000000001" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00 " }
@"runtime/gc.layout:62-0001" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00\00" }
@"reflect/types.type:basic:complex128" = linkonce_odr constant { i8, ptr } { i8 80, ptr @"reflect/types.type:pointer:basic:complex128" }, align 4
@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant { i8, i16, ptr } { i8 -43, i16 0, ptr @"reflect/types.type:basic:complex128" }, align 4
@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant { i8, i16, ptr, { i32, [0 x ptr] } } { i8 -43, i16 0, ptr @"reflect/types.type:basic:complex128", { i32, [0 x ptr] } zeroinitializer }, align 4

; Function Attrs: allockind("alloc,zeroed") allocsize(0)
declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0
Expand Down
Loading
Loading