diff --git a/mapstructure.go b/mapstructure.go index 6cb703cf..7715a687 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -425,6 +425,7 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e return nil } + hooked := false if d.config.DecodeHook != nil { // We have a DecodeHook, so let's pre-process the input. var err error @@ -432,39 +433,49 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e if err != nil { return fmt.Errorf("error decoding '%s': %s", name, err) } + if input != nil && reflect.TypeOf(input).Comparable() && + inputVal.Type().Comparable() && + input != inputVal.Interface() { + //hook was changed input value + hooked = true + } } var err error - outputKind := getKind(outVal) addMetaKey := true - switch outputKind { - case reflect.Bool: - err = d.decodeBool(name, input, outVal) - case reflect.Interface: - err = d.decodeBasic(name, input, outVal) - case reflect.String: - err = d.decodeString(name, input, outVal) - case reflect.Int: - err = d.decodeInt(name, input, outVal) - case reflect.Uint: - err = d.decodeUint(name, input, outVal) - case reflect.Float32: - err = d.decodeFloat(name, input, outVal) - case reflect.Struct: - err = d.decodeStruct(name, input, outVal) - case reflect.Map: - err = d.decodeMap(name, input, outVal) - case reflect.Ptr: - addMetaKey, err = d.decodePtr(name, input, outVal) - case reflect.Slice: - err = d.decodeSlice(name, input, outVal) - case reflect.Array: - err = d.decodeArray(name, input, outVal) - case reflect.Func: - err = d.decodeFunc(name, input, outVal) - default: - // If we reached this point then we weren't able to decode it - return fmt.Errorf("%s: unsupported type: %s", name, outputKind) + if hooked && reflect.TypeOf(input).AssignableTo(outVal.Type()) { + outVal.Set(reflect.ValueOf(input)) + } else { + outputKind := getKind(outVal) + switch outputKind { + case reflect.Bool: + err = d.decodeBool(name, input, outVal) + case reflect.Interface: + err = d.decodeBasic(name, input, outVal) + case reflect.String: + err = d.decodeString(name, input, outVal) + case reflect.Int: + err = d.decodeInt(name, input, outVal) + case reflect.Uint: + err = d.decodeUint(name, input, outVal) + case reflect.Float32: + err = d.decodeFloat(name, input, outVal) + case reflect.Struct: + err = d.decodeStruct(name, input, outVal) + case reflect.Map: + err = d.decodeMap(name, input, outVal) + case reflect.Ptr: + addMetaKey, err = d.decodePtr(name, input, outVal) + case reflect.Slice: + err = d.decodeSlice(name, input, outVal) + case reflect.Array: + err = d.decodeArray(name, input, outVal) + case reflect.Func: + err = d.decodeFunc(name, input, outVal) + default: + // If we reached this point then we weren't able to decode it + return fmt.Errorf("%s: unsupported type: %s", name, outputKind) + } } // If we reached here, then we successfully decoded SOMETHING, so diff --git a/mapstructure_test.go b/mapstructure_test.go index ccd474c1..c283a865 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -1022,6 +1022,88 @@ func TestDecode_FuncHook(t *testing.T) { } } +func TestDecode_SamePointerHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "A": "foo", + "B": "foo", + } + + ptr := stringPtr("bar") + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if str, ok := v.(string); ok && str == "foo" { + return ptr, nil + } + return v, nil + } + + var result struct { + A *string + B *string + } + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + if result.A != result.B { + t.Errorf("decoded pointers should be the same") + } +} + +func TestDecode_SameInterfaceHook(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "A": "foo", + "B": "foo", + } + + var intf io.Reader = strings.NewReader("bar") + + decodeHook := func(f, t reflect.Type, v interface{}) (interface{}, error) { + if str, ok := v.(string); ok && str == "foo" { + return intf, nil + } + return v, nil + } + + var result struct { + A io.Reader + B io.Reader + } + config := &DecoderConfig{ + DecodeHook: decodeHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + if result.A != result.B { + t.Errorf("decoded interfaces should be the same") + } +} + func TestDecode_NonStruct(t *testing.T) { t.Parallel()