From 4f242969731f222fe861bcd0aa5bbe25d9ae3367 Mon Sep 17 00:00:00 2001 From: David Finkel Date: Mon, 29 Jan 2024 14:59:09 -0500 Subject: [PATCH] pflag: add integral-slice and uintptr support Mirroring the flag support, add support for integral-typed flags. Also, fix a bug in the pflag vistor callback that paniced due to an incorrect conversion. (exercised via the new the uintptr support) --- sources/pflag/pflag.go | 67 +++++++++- sources/pflag/pflag_test.go | 235 ++++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 6 deletions(-) diff --git a/sources/pflag/pflag.go b/sources/pflag/pflag.go index 221e4a1..410933c 100644 --- a/sources/pflag/pflag.go +++ b/sources/pflag/pflag.go @@ -42,15 +42,29 @@ var ( int32Type = reflect.TypeOf(int32(0)) int64Type = reflect.TypeOf(int64(0)) - uintType = reflect.TypeOf(uint(0)) - uint8Type = reflect.TypeOf(uint8(0)) - uint16Type = reflect.TypeOf(uint16(0)) - uint32Type = reflect.TypeOf(uint32(0)) - uint64Type = reflect.TypeOf(uint64(0)) + uintType = reflect.TypeOf(uint(0)) + uint8Type = reflect.TypeOf(uint8(0)) + uint16Type = reflect.TypeOf(uint16(0)) + uint32Type = reflect.TypeOf(uint32(0)) + uint64Type = reflect.TypeOf(uint64(0)) + uintptrType = reflect.TypeOf(uintptr(0)) complex64Type = reflect.TypeOf((*complex64)(nil)) complex128Type = reflect.TypeOf((*complex128)(nil)) + intSliceType = reflect.SliceOf(intType) + int8SliceType = reflect.SliceOf(int8Type) + int16SliceType = reflect.SliceOf(int16Type) + int32SliceType = reflect.SliceOf(int32Type) + int64SliceType = reflect.SliceOf(int64Type) + + uintSliceType = reflect.SliceOf(uintType) + uint8SliceType = reflect.SliceOf(uint8Type) + uint16SliceType = reflect.SliceOf(uint16Type) + uint32SliceType = reflect.SliceOf(uint32Type) + uint64SliceType = reflect.SliceOf(uint64Type) + uintptrSliceType = reflect.SliceOf(uintptrType) + // Verify that Set implements the dials.Source interface _ dials.Source = (*Set)(nil) ) @@ -313,6 +327,8 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { f = s.Flags.Uint32P(name, shorthand, fieldVal.Convert(uint32Type).Interface().(uint32), help) case reflect.Uint64: f = s.Flags.Uint64P(name, shorthand, fieldVal.Convert(uint64Type).Interface().(uint64), help) + case reflect.Uintptr: + f = s.Flags.Uint64P(name, shorthand, uint64(fieldVal.Convert(uintptrType).Interface().(uintptr)), help) case reflect.Slice, reflect.Map: switch ft { case stringSlice: @@ -326,6 +342,44 @@ func (s *Set) registerFlags(tmpl reflect.Value, ptyp reflect.Type) error { case stringSet: f = fieldVal.Addr().Interface() s.Flags.VarP(flaghelper.NewStringSetFlag(fieldVal.Addr().Interface().(*map[string]struct{})), name, shorthand, help) + + // signed integral slices + case intSliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int)), name, shorthand, help) + case int8SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int8)), name, shorthand, help) + case int16SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int16)), name, shorthand, help) + case int32SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int32)), name, shorthand, help) + case int64SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewSignedIntegralSlice(f.(*[]int64)), name, shorthand, help) + + // unsigned integral slices + case uintSliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint)), name, shorthand, help) + case uint8SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint8)), name, shorthand, help) + case uint16SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint16)), name, shorthand, help) + case uint32SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint32)), name, shorthand, help) + case uint64SliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uint64)), name, shorthand, help) + case uintptrSliceType: + f = fieldVal.Addr().Interface() + s.Flags.VarP(flaghelper.NewUnsignedIntegralSlice(f.(*[]uintptr)), name, shorthand, help) + default: // Unhandled type. Just keep going. continue @@ -431,7 +485,8 @@ func (s *Set) Value(_ context.Context, t *dials.Type) (reflect.Value, error) { return } - cfval := fval.Convert(stripTypePtr(ffield.Type())) + // fval is always a pointer, so dereference it before converting to the final type + cfval := fval.Elem().Convert(stripTypePtr(ffield.Type())) switch ffield.Kind() { case reflect.Ptr: // common case diff --git a/sources/pflag/pflag_test.go b/sources/pflag/pflag_test.go index cbeddc5..1a37c54 100644 --- a/sources/pflag/pflag_test.go +++ b/sources/pflag/pflag_test.go @@ -81,6 +81,7 @@ func TestDefaultVals(t *testing.T) { type otherUint16 uint16 type otherUint32 uint32 type otherUint64 uint64 + type otherUintptr uintptr type otherFloat32 float32 type otherFloat64 float64 type otherComplex64 complex64 @@ -99,6 +100,7 @@ func TestDefaultVals(t *testing.T) { OUint16 otherUint16 OUint32 otherUint32 OUint64 otherUint64 + OUintptr otherUintptr OFloat32 otherFloat32 OFloat64 otherFloat64 OComplex64 otherComplex64 @@ -118,6 +120,7 @@ func TestDefaultVals(t *testing.T) { OUint16: 3, OUint32: 4, OUint64: 5, + OUintptr: 0xffff_f333_7777, OFloat32: 6.0, OFloat64: 7.0, OComplex64: 8 + 2i, @@ -200,6 +203,24 @@ func TestPFlags(t *testing.T) { args: []string{"--a=42"}, expected: &struct{ A int }{A: 42}, }, + { + name: "basic_int_slice_set", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []int }{A: []int{4}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=42,33"}, + expected: &struct{ A []int }{A: []int{42, 33}}, + }, + { + name: "basic_uint_slice_set", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint }{A: []uint{4}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=42,33"}, + expected: &struct{ A []uint }{A: []uint{42, 33}}, + }, { name: "basic_float32_set", tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { @@ -294,6 +315,202 @@ func TestPFlags(t *testing.T) { expected: nil, expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: strconv.ParseInt: parsing \"1000000\": value out of range", }, + { + name: "basic_uint16_slice_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint16 }{A: []uint16{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []uint16 }{A: []uint16{10}}, + }, + { + name: "basic_uint16_slice_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint16 }{A: []uint16{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128,32"}, + expected: &struct{ A []uint16 }{A: []uint16{128, 32}}, + }, + { + name: "basic_uint16_slice_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint16 }{A: []uint16{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=1000000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: failed to parse integer index 0: strconv.ParseUint: parsing \"1000000\": value out of range", + }, + { + name: "basic_uint32_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint32 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128"}, + expected: &struct{ A uint32 }{A: 128}, + }, + { + name: "basic_uint32_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint32 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=100_000_000_000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"100_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000\": value out of range", + }, + { + name: "basic_uint32_slice_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint32 }{A: []uint32{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []uint32 }{A: []uint32{10}}, + }, + { + name: "basic_uint32_slice_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint32 }{A: []uint32{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128,32"}, + expected: &struct{ A []uint32 }{A: []uint32{128, 32}}, + }, + { + name: "basic_uint8_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint8 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A uint8 }{A: 10}, + }, + { + name: "basic_uint8_slice_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint8 }{A: []uint8{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []uint8 }{A: []uint8{10}}, + }, + { + name: "basic_uint8_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint8 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=125"}, + expected: &struct{ A uint8 }{A: 125}, + }, + { + name: "basic_uint8_slice_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint8 }{A: []uint8{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=125"}, + expected: &struct{ A []uint8 }{A: []uint8{125}}, + }, + { + name: "basic_uint8_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint8 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=1000000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: strconv.ParseUint: parsing \"1000000\": value out of range", + }, + { + name: "basic_uint8_slice_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint8 }{A: []uint8{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=1000000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"1000000\" for \"--a\" flag: failed to parse integer index 0: strconv.ParseUint: parsing \"1000000\": value out of range", + }, + { + name: "basic_uint64_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint64 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128"}, + expected: &struct{ A uint64 }{A: 128}, + }, + { + name: "basic_uint64_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uint64 }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=100_000_000_000_000_000_000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"100_000_000_000_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000_000_000_000\": value out of range", + }, + { + name: "basic_uint64_slice_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint64 }{A: []uint64{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []uint64 }{A: []uint64{10}}, + }, + { + name: "basic_uint64_slice_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uint64 }{A: []uint64{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128,32"}, + expected: &struct{ A []uint64 }{A: []uint64{128, 32}}, + }, + { + name: "basic_uintptr_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uintptr }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128"}, + expected: &struct{ A uintptr }{A: 128}, + }, + { + name: "basic_uintptr_set_overflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A uintptr }{A: 10} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=100_000_000_000_000_000_000"}, + expected: nil, + expErr: "failed to parse pflags: invalid argument \"100_000_000_000_000_000_000\" for \"--a\" flag: strconv.ParseUint: parsing \"100_000_000_000_000_000_000\": value out of range", + }, + { + name: "basic_uintptr_slice_default", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uintptr }{A: []uintptr{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []uintptr }{A: []uintptr{10}}, + }, + { + name: "basic_uintptr_slice_set_nooverflow", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []uintptr }{A: []uintptr{10}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{"--a=128,32"}, + expected: &struct{ A []uintptr }{A: []uintptr{128, 32}}, + }, + { name: "map_string_string_set", tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { @@ -348,6 +565,24 @@ func TestPFlags(t *testing.T) { args: []string{}, expected: &struct{ A map[string]struct{} }{A: map[string]struct{}{"i": {}}}, }, + { + name: "int_slice_default_val", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []int }{A: []int{33, 22}} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []int }{A: []int{33, 22}}, + }, + { + name: "int_slice_default_nil", + tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) { + cfg := struct{ A []int }{A: []int(nil)} + return &cfg, testWrapDials(&cfg) + }, + args: []string{}, + expected: &struct{ A []int }{A: nil}, + }, { name: "complex128_default", tmplCB: func() (any, func(ctx context.Context, src *Set) (any, error)) {