diff --git a/sdk/client/msgraph/client.go b/sdk/client/msgraph/client.go index 3cf67d08f25..0a74b22cad1 100644 --- a/sdk/client/msgraph/client.go +++ b/sdk/client/msgraph/client.go @@ -35,7 +35,7 @@ type Client struct { tenantId string } -func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) { +func NewClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) { endpoint, ok := api.Endpoint() if !ok { return nil, fmt.Errorf("no `endpoint` was returned for this environment") @@ -49,6 +49,11 @@ func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVe }, nil } +// Deprecated: use NewClient instead +func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) { + return NewClient(api, serviceName, apiVersion) +} + func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) { // TODO move these validations to base client method if _, ok := ctx.Deadline(); !ok { diff --git a/sdk/client/resourcemanager/client.go b/sdk/client/resourcemanager/client.go index 8c9b67a9c4f..9cd3335f1ba 100644 --- a/sdk/client/resourcemanager/client.go +++ b/sdk/client/resourcemanager/client.go @@ -23,7 +23,7 @@ type Client struct { apiVersion string } -func NewResourceManagerClient(api environments.Api, serviceName, apiVersion string) (*Client, error) { +func NewClient(api environments.Api, serviceName, apiVersion string) (*Client, error) { endpoint, ok := api.Endpoint() if !ok { return nil, fmt.Errorf("no `endpoint` was returned for this environment") @@ -38,6 +38,11 @@ func NewResourceManagerClient(api environments.Api, serviceName, apiVersion stri }, nil } +// Deprecated: use NewClient instead +func NewResourceManagerClient(api environments.Api, serviceName, apiVersion string) (*Client, error) { + return NewClient(api, serviceName, apiVersion) +} + func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) { // TODO move these validations to base client method if _, ok := ctx.Deadline(); !ok { diff --git a/sdk/nullable/nullable.go b/sdk/nullable/nullable.go new file mode 100644 index 00000000000..5e2126e533c --- /dev/null +++ b/sdk/nullable/nullable.go @@ -0,0 +1,109 @@ +package nullable + +import ( + "bytes" + "encoding/json" +) + +var _ json.Marshaler = Type[string]{} +var _ json.Unmarshaler = &Type[string]{} + +type Type[T comparable] map[bool]T + +// Value returns a new Type[T], setting its type and value to the provided value +func Value[T comparable](t T) Type[T] { + var n Type[T] + n.Set(t) + return n +} + +// NoZero returns a new Type[T], setting its type and value, whilst also nulling the value if it was set to +// its zero value. This ensures that zero values are sent as null. +func NoZero[T comparable](t T) Type[T] { + var n Type[T] + n.SetNoZero(t) + return n +} + +// Get retrieves the underlying value, if present, and returns nil if the value is null +func (t Type[T]) Get() *T { + var empty T + if t.IsNull() { + return nil + } + if t.IsSet() { + ret := t[true] + return &ret + } + return &empty +} + +// GetOrZero retrieves the underlying value, if present, and returns the zero value if null +func (t Type[T]) GetOrZero() T { + var empty T + val := t.Get() + if val == nil { + return empty + } + return *val +} + +// Set sets the underlying value to a given value +func (t *Type[T]) Set(value T) { + *t = map[bool]T{true: value} +} + +// SetNoZero sets the underlying value to a given value, whilst also nulling the value if it was set to +// its zero value. This ensures that zero values are sent as null. +func (t *Type[T]) SetNoZero(value T) { + var empty T + *t = map[bool]T{value != empty: value} +} + +// SetNull clears the value and ensures a value of `null` +func (t *Type[T]) SetNull() { + var empty T + *t = map[bool]T{false: empty} +} + +// SetUnspecified clears the value +func (t *Type[T]) SetUnspecified() { + *t = map[bool]T{} +} + +// IsNull indicates whether the value was set to `null` +func (t Type[T]) IsNull() bool { + _, foundNull := t[false] + return foundNull +} + +// IsSet indicates whether a value is set +func (t Type[T]) IsSet() bool { + return len(t) != 0 +} + +func (t Type[T]) MarshalJSON() ([]byte, error) { + // note: if value was unspecified, and `omitempty` is set on the field tags, `json.Marshal` will omit this field + // if value was specified, and `null`, marshal it + if t.IsNull() { + return []byte("null"), nil + } + // otherwise, we have a value, so marshal it + return json.Marshal(t[true]) +} + +func (t *Type[T]) UnmarshalJSON(data []byte) error { + // note: if value is unspecified, UnmarshalJSON won't be called + // if value is specified and `null` + if bytes.Equal(data, []byte("null")) { + t.SetNull() + return nil + } + // otherwise, we have an actual value, so parse it + var val T + if err := json.Unmarshal(data, &val); err != nil { + return err + } + t.Set(val) + return nil +} diff --git a/sdk/nullable/nullable_test.go b/sdk/nullable/nullable_test.go new file mode 100644 index 00000000000..4a98e561a2b --- /dev/null +++ b/sdk/nullable/nullable_test.go @@ -0,0 +1,150 @@ +package nullable + +import ( + "bytes" + "testing" +) + +func TestTypeMarshalBool(t *testing.T) { + testCases := []struct { + value Type[bool] + expected []byte + }{ + { + value: Value(true), + expected: []byte(`true`), + }, + { + value: Value(false), + expected: []byte(`false`), + }, + { + value: NoZero(true), + expected: []byte(`true`), + }, + { + value: NoZero(false), + expected: []byte(`null`), + }, + } + + for i, testCase := range testCases { + result, err := testCase.value.MarshalJSON() + if err != nil { + t.Errorf("test case %d: %v", i, err) + } + if !bytes.Equal(result, testCase.expected) { + t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result)) + } + } +} + +func TestTypeMarshalFloat(t *testing.T) { + testCases := []struct { + value Type[float64] + expected []byte + }{ + { + value: Value(123.45), + expected: []byte(`123.45`), + }, + { + value: Value(0.0), + expected: []byte(`0`), + }, + { + value: NoZero(123.45), + expected: []byte(`123.45`), + }, + { + value: NoZero(-123.45), + expected: []byte(`-123.45`), + }, + { + value: NoZero(0.0), + expected: []byte(`null`), + }, + } + + for i, testCase := range testCases { + result, err := testCase.value.MarshalJSON() + if err != nil { + t.Errorf("test case %d: %v", i, err) + } + if !bytes.Equal(result, testCase.expected) { + t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result)) + } + } +} + +func TestTypeMarshalInt(t *testing.T) { + testCases := []struct { + value Type[int] + expected []byte + }{ + { + value: Value(123), + expected: []byte(`123`), + }, + { + value: Value(0), + expected: []byte(`0`), + }, + { + value: NoZero(123), + expected: []byte(`123`), + }, + { + value: NoZero(-123), + expected: []byte(`-123`), + }, + { + value: NoZero(0), + expected: []byte(`null`), + }, + } + + for i, testCase := range testCases { + result, err := testCase.value.MarshalJSON() + if err != nil { + t.Errorf("test case %d: %v", i, err) + } + if !bytes.Equal(result, testCase.expected) { + t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result)) + } + } +} + +func TestTypeMarshalString(t *testing.T) { + testCases := []struct { + value Type[string] + expected []byte + }{ + { + value: Value("foo"), + expected: []byte(`"foo"`), + }, + { + value: Value(""), + expected: []byte(`""`), + }, + { + value: NoZero("foo"), + expected: []byte(`"foo"`), + }, + { + value: NoZero(""), + expected: []byte(`null`), + }, + } + + for i, testCase := range testCases { + result, err := testCase.value.MarshalJSON() + if err != nil { + t.Errorf("test case %d: %v", i, err) + } + if !bytes.Equal(result, testCase.expected) { + t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result)) + } + } +}