diff --git a/.gitignore b/.gitignore index 14b870f7927..669d7de5094 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .fleet/ .idea/ +.vscode/ .DS_Store tmp/ vendor/ diff --git a/sdk/nullable/sentinel_nullable.go b/sdk/nullable/sentinel_nullable.go new file mode 100644 index 00000000000..b4c4cecf8c1 --- /dev/null +++ b/sdk/nullable/sentinel_nullable.go @@ -0,0 +1,114 @@ +package nullable + +import ( + "encoding/json" + "reflect" + "strings" + "sync" +) + +// holds sentinel values used to send nulls +var nullSentinels map[reflect.Type]any = map[reflect.Type]any{} +var nullablesLock sync.RWMutex + +// NullValue is used to send an explicit 'null' within a request. +// This is typically used in JSON-MERGE-PATCH operations to delete a value. +// Type arugment `T` MUST be a pointer type (pointer, map, or slice) +// for interface type's null value, a pointer to implementor type is required +func NullValue[T any]() T { + t := reflect.TypeFor[T]() + + nullablesLock.RLock() + v, found := nullSentinels[t] + nullablesLock.RUnlock() + + if found { + // return the sentinel object + if t.Kind() == reflect.Interface { + var zero T + return zero + } + return v.(T) + } + + // promote to exclusive lock and check again (double-checked locking pattern) + nullablesLock.Lock() + defer nullablesLock.Unlock() + + v, found = nullSentinels[t] + if !found { + var o reflect.Value + switch k := t.Kind(); k { + case reflect.Map: + o = reflect.MakeMap(t) + case reflect.Slice: + o = reflect.MakeSlice(t, 1, 1) + default: + // let it panic here if non-pointer type is passed + o = reflect.New(t.Elem()) + } + v = o.Interface() + nullSentinels[t] = v + } + // return the sentinel object + return v.(T) +} + +func IsNullValue[T any](v T) bool { + t := reflect.TypeOf(v) + nullablesLock.RLock() + defer nullablesLock.RUnlock() + + // if found, it MUST be a pointer, so never panic here + if o, found := nullSentinels[t]; found { + o1 := reflect.ValueOf(o) + v1 := reflect.ValueOf(v) + return o1.Pointer() == v1.Pointer() + } + return false +} + +func MarshalNullableStruct(obj interface{}) ([]byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Struct: + return marshalStruct(v) + } + return json.Marshal(obj) +} + +func marshalStruct(v reflect.Value) ([]byte, error) { + m := make(map[string]any) + t := v.Type() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.PkgPath != "" { + continue + } + jsonName := field.Name + omitEmtpy := false + if tag := field.Tag.Get("json"); tag != "" { + opts := strings.Split(tag, ",") + jsonName = opts[0] + for _, opt := range opts[1:] { + if opt == "omitempty" { + omitEmtpy = true + break + } + } + } + + rval := v.Field(i) + val := rval.Interface() + + if val == nil || (omitEmtpy && rval.IsZero()) { + continue + } else if IsNullValue(val) { + m[jsonName] = nil + } else { + m[jsonName] = val + } + } + return json.Marshal(m) +} diff --git a/sdk/nullable/sentinel_nullable_test.go b/sdk/nullable/sentinel_nullable_test.go new file mode 100644 index 00000000000..d628179f7af --- /dev/null +++ b/sdk/nullable/sentinel_nullable_test.go @@ -0,0 +1,230 @@ +package nullable_test + +import ( + "encoding/json" + "reflect" + "strings" + "testing" + + "github.com/hashicorp/go-azure-sdk/sdk/nullable" +) + +type InnerStruct struct { + InnerName *string `json:"inner_name,omitempty"` + InnerID string `json:"inner_id"` + Number int64 `json:"number,omitempty"` +} + +func (i InnerStruct) MarshalJSON() ([]byte, error) { + return nullable.MarshalNullableStruct(i) +} + +type TestStruct struct { + ID string `json:"id"` + OmitID string `json:"omit_id,omitempty"` + Name *string `json:"name,omitempty"` + Age *float64 `json:"age,omitempty"` + Address *string `json:"address,omitempty"` + Inner *InnerStruct `json:"inner,omitempty"` +} + +// MarshalJSON implements json.Marshaler. +func (t TestStruct) MarshalJSON() ([]byte, error) { + return nullable.MarshalNullableStruct(t) +} + +var _ json.Marshaler = TestStruct{} + +func TestNullableValuePanic(t *testing.T) { + defer func() { + if e := recover(); e == nil { + t.Fatalf("Expected panic, but got nil") + } else if !strings.Contains(e.(string), "Elem of invalid type") { + t.Fatalf("Expected panic of invalid type but got %v", e) + } + }() + nullable.NullValue[string]() +} + +func TestMarshalNullableNil(t *testing.T) { + // nullable field address is nil, it should be omitted + name := "John Doe" + age := 30.0 + obj := TestStruct{ + Name: &name, + Age: &age, + } + + expected := map[string]interface{}{ + "age": age, + "id": "", + "name": name, + } + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("json.Unmarshal returned an error: %v", err) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %v, but got %v", expected, result) + } +} + +func TestMarshalNullable(t *testing.T) { + // nullable field address is set to null value, it should be included + name := "John Doe" + age := 30.0 + obj := TestStruct{ + Name: &name, + Age: &age, + } + obj.Address = nullable.NullValue[*string]() + + expected := map[string]interface{}{ + "age": age, + "address": nil, + "id": "", + "name": name, + } + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("json.Unmarshal returned an error: %v", err) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %v, but got %v", expected, result) + } +} + +func TestMarshalNullableInnerStruct(t *testing.T) { + name := "John Doe" + age := 30.0 + obj := TestStruct{ + Name: &name, + Age: &age, + } + obj.Address = nullable.NullValue[*string]() + obj.Inner = nullable.NullValue[*InnerStruct]() + + expected := map[string]interface{}{ + "address": nil, + "age": age, + "id": "", + "inner": nil, + "name": name, + } + expectedBytes, _ := json.Marshal(expected) + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + if !reflect.DeepEqual(data, expectedBytes) { + t.Errorf("Expected %s, but got %s", expectedBytes, data) + } +} + +func TestMarshalNullableWithInnerNullale(t *testing.T) { + name := "John Doe" + age := 30.0 + obj := TestStruct{ + Name: &name, + Age: &age, + Inner: &InnerStruct{ + InnerID: "", + InnerName: nullable.NullValue[*string](), + }, + } + obj.Address = nullable.NullValue[*string]() + + expected := map[string]interface{}{ + "address": nil, + "age": age, + "id": "", + "inner": map[string]interface{}{ + "inner_id": "", // inner_id is not omitempty flagged + "inner_name": nil, + }, + "name": name, + } + expectedBytes, _ := json.Marshal(expected) + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + if !reflect.DeepEqual(data, expectedBytes) { + t.Errorf("Expected %s, but got %s", expectedBytes, data) + } +} + +type ITest interface { + Foo() string +} + +type TestImpl struct{} + +func (t TestImpl) Foo() string { + return "foo" +} + +type NullableInterface struct { + ITest ITest `json:"itest,omitempty"` +} + +func (n NullableInterface) MarshalJSON() ([]byte, error) { + return nullable.MarshalNullableStruct(n) +} + +func TestMarshalNullableWithInterface(t *testing.T) { + obj := NullableInterface{ + ITest: nil, + } + + expected := map[string]interface{}{} + expectedBytes, _ := json.Marshal(expected) + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + if !reflect.DeepEqual(data, expectedBytes) { + t.Errorf("Expected %s, but got %s", expectedBytes, data) + } +} + +func TestMarshalNullableWithInterfaceNullValue(t *testing.T) { + // for interface, set the null value of it's implementation + obj := NullableInterface{ + ITest: nullable.NullValue[*TestImpl](), + } + + expected := map[string]interface{}{ + "itest": nil, + } + expectedBytes, _ := json.Marshal(expected) + + data, err := json.Marshal(obj) + if err != nil { + t.Fatalf("MarshalNullable returned an error: %v", err) + } + + if !reflect.DeepEqual(data, expectedBytes) { + t.Errorf("Expected %s, but got %s", expectedBytes, data) + } +}