From 6a866d47e79bb94c30490cd305902bf6dc5235ad Mon Sep 17 00:00:00 2001 From: Ethan Lewis Date: Thu, 31 Oct 2024 13:29:00 -0700 Subject: [PATCH 1/3] feat: add NumIn, NumOut, In, Out for function reflection --- compiler/interface.go | 35 +++++++++++++++++++++--- src/reflect/type.go | 62 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 11 deletions(-) diff --git a/compiler/interface.go b/compiler/interface.go index dffaeec0ad..1c881c5888 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -243,12 +243,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { typeFieldTypes = append(typeFieldTypes, types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), ) - // TODO: methods case *types.Signature: typeFieldTypes = append(typeFieldTypes, types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "inCount", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "outCount", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "fields", types.NewArray(types.Typ[types.UnsafePointer], int64(typ.Params().Len()+typ.Results().Len()))), ) - // TODO: signature params and return values } if hasMethodSet { // This method set is appended at the start of the struct. It is @@ -408,8 +409,34 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} // TODO: methods case *types.Signature: - typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} - // TODO: params, return values, etc + v := typ.Results().Len() + if typ.Variadic() { + // set variadic to 1 + v = v | (1 << 15) + } else { + // set variadic to 0 + v = v & ^(1 << 15) + } + + var vars []*types.Var + for i := 0; i < typ.Params().Len(); i++ { + vars = append(vars, typ.Params().At(i)) + } + + for i := 0; i < typ.Results().Len(); i++ { + vars = append(vars, typ.Results().At(i)) + } + + var fields []llvm.Value + for i := 0; i < len(vars); i++ { + fields = append(fields, c.getTypeCode(vars[i].Type())) + } + typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ)), + llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.Params().Len()), false), + llvm.ConstInt(c.ctx.Int16Type(), uint64(v), false), + } + + typeFields = append(typeFields, llvm.ConstArray(c.dataPtrType, fields)) } // Prepend metadata byte. typeFields = append([]llvm.Value{ diff --git a/src/reflect/type.go b/src/reflect/type.go index b18a1a9d4f..67e3cb972d 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -487,6 +487,14 @@ type structField struct { data unsafe.Pointer // various bits of information, packed in a byte array } +type funcType struct { + rawType + ptrType *rawType + inCount uint16 + outCount uint16 + fields [1]*rawType // the remaining fields are all of type funcField +} + // Equivalent to (go/types.Type).Underlying(): if this is a named type return // the underlying type, else just return the type itself. func (t *rawType) underlying() *rawType { @@ -1038,15 +1046,25 @@ func (t *rawType) ConvertibleTo(u Type) bool { } func (t *rawType) IsVariadic() bool { - panic("unimplemented: (reflect.Type).IsVariadic()") + // need to test if bool mapped to int set by compiler + if t.Kind() != Func { + panic("reflect: IsVariadic of non-func type") + } + return (*funcType)(unsafe.Pointer(t)).outCount&(1<<15) != 0 } func (t *rawType) NumIn() int { - panic("unimplemented: (reflect.Type).NumIn()") + if t.Kind() != Func { + panic("reflect: NumIn of non-func type") + } + return int((*funcType)(unsafe.Pointer(t)).inCount) } func (t *rawType) NumOut() int { - panic("unimplemented: (reflect.Type).NumOut()") + if t.Kind() != Func { + panic("reflect: NumOut of non-func type") + } + return int((*funcType)(unsafe.Pointer(t)).outCount) } func (t *rawType) NumMethod() int { @@ -1110,12 +1128,42 @@ func (t *rawType) Key() Type { return t.key() } -func (t rawType) In(i int) Type { - panic("unimplemented: (reflect.Type).In()") +// addChecked returns p+x. +// +// The whySafe string is ignored, so that the function still inlines +// as efficiently as p+x, but all call sites should use the string to +// record why the addition is safe, which is to say why the addition +// does not cause x to advance to the very end of p's allocation +// and therefore point incorrectly at the next block in memory. +func addChecked(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) } -func (t rawType) Out(i int) Type { - panic("unimplemented: (reflect.Type).Out()") +func (t *rawType) In(i int) Type { + if t.Kind() != Func { + panic(errTypeField) + } + descriptor := (*funcType)(unsafe.Pointer(t.underlying())) + if uint(i) >= uint(descriptor.inCount) { + panic("reflect: field index out of range") + } + + pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*rawType)(*(**rawType)(pointer)) +} + +func (t *rawType) Out(i int) Type { + if t.Kind() != Func { + panic(errTypeField) + } + i = i + int((*funcType)(unsafe.Pointer(t)).inCount) + descriptor := (*funcType)(unsafe.Pointer(t.underlying())) + if uint(i) > uint(descriptor.inCount) && uint(i) <= uint(descriptor.outCount+descriptor.inCount) { + panic("reflect: field index out of range") + } + + pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*rawType)(*(**rawType)(pointer)) } // OverflowComplex reports whether the complex128 x cannot be represented by type t. From ee7dde40e28f218596a889392c40feb64a254a85 Mon Sep 17 00:00:00 2001 From: jeff1010322 Date: Fri, 1 Nov 2024 12:20:44 -0400 Subject: [PATCH 2/3] feat: update funcStruct to use variadic, impl String Co-authored-by: Ethan Lewis --- compiler/interface.go | 13 +++++-------- src/reflect/type.go | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/compiler/interface.go b/compiler/interface.go index 1c881c5888..a4a6b15a52 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -248,6 +248,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), types.NewVar(token.NoPos, nil, "inCount", types.Typ[types.Uint16]), types.NewVar(token.NoPos, nil, "outCount", types.Typ[types.Uint16]), + types.NewVar(token.NoPos, nil, "variadic", types.Typ[types.Bool]), types.NewVar(token.NoPos, nil, "fields", types.NewArray(types.Typ[types.UnsafePointer], int64(typ.Params().Len()+typ.Results().Len()))), ) } @@ -409,13 +410,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} // TODO: methods case *types.Signature: - v := typ.Results().Len() + v := 0 if typ.Variadic() { - // set variadic to 1 - v = v | (1 << 15) - } else { - // set variadic to 0 - v = v & ^(1 << 15) + v = 1 } var vars []*types.Var @@ -433,9 +430,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { } typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ)), llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.Params().Len()), false), - llvm.ConstInt(c.ctx.Int16Type(), uint64(v), false), + llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.Results().Len()), false), + llvm.ConstInt(c.ctx.Int1Type(), uint64(v), false), } - typeFields = append(typeFields, llvm.ConstArray(c.dataPtrType, fields)) } // Prepend metadata byte. diff --git a/src/reflect/type.go b/src/reflect/type.go index 67e3cb972d..315dbe51e7 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -489,9 +489,10 @@ type structField struct { type funcType struct { rawType - ptrType *rawType + ptrTo *rawType inCount uint16 outCount uint16 + variadic bool fields [1]*rawType // the remaining fields are all of type funcField } @@ -606,6 +607,28 @@ func (t *rawType) String() string { } s += " }" return s + case Func: + + f := "func(" + for i := 0; i < t.NumIn(); i++ { + if i > 0 { + f += ", " + } + f += t.In(i).String() + } + f += ") " + + var rets string + for i := 0; i < t.NumOut(); i++ { + if i > 0 { + rets += ", " + } + rets += t.Out(i).String() + } + if t.NumOut() > 1 { + rets = "(" + rets + ")" + } + return f + rets case Interface: // TODO(dgryski): Needs actual method set info return "interface {}" @@ -1050,7 +1073,7 @@ func (t *rawType) IsVariadic() bool { if t.Kind() != Func { panic("reflect: IsVariadic of non-func type") } - return (*funcType)(unsafe.Pointer(t)).outCount&(1<<15) != 0 + return (*funcType)(unsafe.Pointer(t)).variadic } func (t *rawType) NumIn() int { @@ -1156,12 +1179,15 @@ func (t *rawType) Out(i int) Type { if t.Kind() != Func { panic(errTypeField) } - i = i + int((*funcType)(unsafe.Pointer(t)).inCount) + descriptor := (*funcType)(unsafe.Pointer(t.underlying())) - if uint(i) > uint(descriptor.inCount) && uint(i) <= uint(descriptor.outCount+descriptor.inCount) { + if uint(i) >= uint(descriptor.outCount) { panic("reflect: field index out of range") } + // Shift the index by the number of input parameters. + i = i + int((*funcType)(unsafe.Pointer(t)).inCount) + pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) return (*rawType)(*(**rawType)(pointer)) } From b848853a7d58ceccde6e24b649a47bac2b3247b5 Mon Sep 17 00:00:00 2001 From: jeff1010322 Date: Fri, 1 Nov 2024 14:29:05 -0400 Subject: [PATCH 3/3] fix: func reflect impl for named functions, updated tests Co-authored-by: Ethan Lewis --- src/reflect/all_test.go | 4 +++ src/reflect/type.go | 57 ++++++++++++++++++++++++++++++--------- src/reflect/type_test.go | 2 ++ src/reflect/value_test.go | 53 ++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/reflect/all_test.go b/src/reflect/all_test.go index 436bc00341..da30d314a5 100644 --- a/src/reflect/all_test.go +++ b/src/reflect/all_test.go @@ -3312,6 +3312,8 @@ func TestMethodPkgPath(t *testing.T) { } } +*/ + func TestVariadicType(t *testing.T) { // Test example from Type documentation. var f func(x int, y ...float64) @@ -3335,6 +3337,8 @@ func TestVariadicType(t *testing.T) { t.Error(s) } +/* + type inner struct { x int } diff --git a/src/reflect/type.go b/src/reflect/type.go index 315dbe51e7..97f43e6efb 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -608,14 +608,22 @@ func (t *rawType) String() string { s += " }" return s case Func: + isVariadic := t.IsVariadic() f := "func(" for i := 0; i < t.NumIn(); i++ { if i > 0 { f += ", " } - f += t.In(i).String() + + input := t.In(i).String() + if isVariadic && i == t.NumIn()-1 { + f += "..." + input = input[2:] + } + f += input } + f += ") " var rets string @@ -1069,14 +1077,24 @@ func (t *rawType) ConvertibleTo(u Type) bool { } func (t *rawType) IsVariadic() bool { - // need to test if bool mapped to int set by compiler + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic("reflect: IsVariadic of non-func type") } + return (*funcType)(unsafe.Pointer(t)).variadic } func (t *rawType) NumIn() int { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + return int((*funcType)(unsafe.Pointer(named.elem)).inCount) + } + if t.Kind() != Func { panic("reflect: NumIn of non-func type") } @@ -1084,6 +1102,11 @@ func (t *rawType) NumIn() int { } func (t *rawType) NumOut() int { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + return int((*funcType)(unsafe.Pointer(named.elem)).outCount) + } + if t.Kind() != Func { panic("reflect: NumOut of non-func type") } @@ -1163,33 +1186,43 @@ func addChecked(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { } func (t *rawType) In(i int) Type { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic(errTypeField) } - descriptor := (*funcType)(unsafe.Pointer(t.underlying())) - if uint(i) >= uint(descriptor.inCount) { + fType := (*funcType)(unsafe.Pointer(t)) + if uint(i) >= uint(fType.inCount) { panic("reflect: field index out of range") } - pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) - return (*rawType)(*(**rawType)(pointer)) + pointer := (unsafe.Add(unsafe.Pointer(&fType.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*(**rawType)(pointer)) } func (t *rawType) Out(i int) Type { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic(errTypeField) } - descriptor := (*funcType)(unsafe.Pointer(t.underlying())) - if uint(i) >= uint(descriptor.outCount) { + fType := (*funcType)(unsafe.Pointer(t)) + + if uint(i) >= uint(fType.outCount) { panic("reflect: field index out of range") } // Shift the index by the number of input parameters. - i = i + int((*funcType)(unsafe.Pointer(t)).inCount) - - pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) - return (*rawType)(*(**rawType)(pointer)) + i = i + int(fType.inCount) + pointer := (unsafe.Add(unsafe.Pointer(&fType.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*(**rawType)(pointer)) } // OverflowComplex reports whether the complex128 x cannot be represented by type t. diff --git a/src/reflect/type_test.go b/src/reflect/type_test.go index 75784f9666..dd54b0175c 100644 --- a/src/reflect/type_test.go +++ b/src/reflect/type_test.go @@ -13,6 +13,7 @@ func TestTypeFor(t *testing.T) { type ( mystring string myiface interface{} + myfunc func() ) testcases := []struct { @@ -25,6 +26,7 @@ func TestTypeFor(t *testing.T) { {new(mystring), reflect.TypeFor[mystring]()}, {new(any), reflect.TypeFor[any]()}, {new(myiface), reflect.TypeFor[myiface]()}, + {new(myfunc), reflect.TypeFor[myfunc]()}, } for _, tc := range testcases { want := reflect.ValueOf(tc.wantFrom).Elem().Type() diff --git a/src/reflect/value_test.go b/src/reflect/value_test.go index 508b358ad9..2ef4f0f35d 100644 --- a/src/reflect/value_test.go +++ b/src/reflect/value_test.go @@ -487,6 +487,59 @@ func TestTinyStruct(t *testing.T) { } } +func TestTinyFunc(t *testing.T) { + type barStruct struct { + QuxString string + BazInt int + } + + type foobar func(bar barStruct, x int, v ...string) string + + var fb foobar + + reffb := TypeOf(fb) + + numIn := reffb.NumIn() + if want := 3; numIn != want { + t.Errorf("NumIn=%v, want %v", numIn, want) + } + + numOut := reffb.NumOut() + if want := 1; numOut != want { + t.Errorf("NumOut=%v, want %v", numOut, want) + } + + in0 := reffb.In(0) + if want := TypeOf(barStruct{}); in0 != want { + t.Errorf("In(0)=%v, want %v", in0, want) + } + + in1 := reffb.In(1) + if want := TypeOf(0); in1 != want { + t.Errorf("In(1)=%v, want %v", in1, want) + } + + in2 := reffb.In(2) + if want := TypeOf([]string{}); in2 != want { + t.Errorf("In(2)=%v, want %v", in2, want) + } + + out0 := reffb.Out(0) + if want := TypeOf(""); out0 != want { + t.Errorf("Out(0)=%v, want %v", out0, want) + } + + isVariadic := reffb.IsVariadic() + if want := true; isVariadic != want { + t.Errorf("IsVariadic=%v, want %v", isVariadic, want) + } + + if got, want := reffb.String(), "reflect_test.foobar"; got != want { + t.Errorf("Value.String()=%v, want %v", got, want) + } + +} + func TestTinyZero(t *testing.T) { s := "hello, world" sptr := &s