Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improvements to support generated data plane SDKs #1058

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sdk/client/msgraph/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Client struct {
tenantId string
}

func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) {
func NewClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) {
endpoint, ok := api.Endpoint()
if !ok {
return nil, fmt.Errorf("no `endpoint` was returned for this environment")
Expand All @@ -49,6 +49,11 @@ func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVe
}, nil
}

// Deprecated: use NewClient instead
func NewMsGraphClient(api environments.Api, serviceName string, apiVersion ApiVersion) (*Client, error) {
return NewClient(api, serviceName, apiVersion)
}

func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) {
// TODO move these validations to base client method
if _, ok := ctx.Deadline(); !ok {
Expand Down
7 changes: 6 additions & 1 deletion sdk/client/resourcemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Client struct {
apiVersion string
}

func NewResourceManagerClient(api environments.Api, serviceName, apiVersion string) (*Client, error) {
func NewClient(api environments.Api, serviceName, apiVersion string) (*Client, error) {
endpoint, ok := api.Endpoint()
if !ok {
return nil, fmt.Errorf("no `endpoint` was returned for this environment")
Expand All @@ -38,6 +38,11 @@ func NewResourceManagerClient(api environments.Api, serviceName, apiVersion stri
}, nil
}

// Deprecated: use NewClient instead
func NewResourceManagerClient(api environments.Api, serviceName, apiVersion string) (*Client, error) {
return NewClient(api, serviceName, apiVersion)
}

func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) {
// TODO move these validations to base client method
if _, ok := ctx.Deadline(); !ok {
Expand Down
109 changes: 109 additions & 0 deletions sdk/nullable/nullable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package nullable

import (
"bytes"
"encoding/json"
)

var _ json.Marshaler = Type[string]{}
var _ json.Unmarshaler = &Type[string]{}

type Type[T comparable] map[bool]T

// Value returns a new Type[T], setting its type and value to the provided value
func Value[T comparable](t T) Type[T] {
var n Type[T]
n.Set(t)
return n
}

// NoZero returns a new Type[T], setting its type and value, whilst also nulling the value if it was set to
// its zero value. This ensures that zero values are sent as null.
func NoZero[T comparable](t T) Type[T] {
var n Type[T]
n.SetNoZero(t)
return n
}

// Get retrieves the underlying value, if present, and returns nil if the value is null
func (t Type[T]) Get() *T {
var empty T
if t.IsNull() {
return nil
}
if t.IsSet() {
ret := t[true]
return &ret
}
return &empty
}

// GetOrZero retrieves the underlying value, if present, and returns the zero value if null
func (t Type[T]) GetOrZero() T {
var empty T
val := t.Get()
if val == nil {
return empty
}
return *val
}

// Set sets the underlying value to a given value
func (t *Type[T]) Set(value T) {
*t = map[bool]T{true: value}
}

// SetNoZero sets the underlying value to a given value, whilst also nulling the value if it was set to
// its zero value. This ensures that zero values are sent as null.
func (t *Type[T]) SetNoZero(value T) {
var empty T
*t = map[bool]T{value != empty: value}
}

// SetNull clears the value and ensures a value of `null`
func (t *Type[T]) SetNull() {
var empty T
*t = map[bool]T{false: empty}
}

// SetUnspecified clears the value
func (t *Type[T]) SetUnspecified() {
*t = map[bool]T{}
}

// IsNull indicates whether the value was set to `null`
func (t Type[T]) IsNull() bool {
_, foundNull := t[false]
return foundNull
}

// IsSet indicates whether a value is set
func (t Type[T]) IsSet() bool {
return len(t) != 0
}

func (t Type[T]) MarshalJSON() ([]byte, error) {
// note: if value was unspecified, and `omitempty` is set on the field tags, `json.Marshal` will omit this field
// if value was specified, and `null`, marshal it
if t.IsNull() {
return []byte("null"), nil
}
// otherwise, we have a value, so marshal it
return json.Marshal(t[true])
}

func (t *Type[T]) UnmarshalJSON(data []byte) error {
// note: if value is unspecified, UnmarshalJSON won't be called
// if value is specified and `null`
if bytes.Equal(data, []byte("null")) {
t.SetNull()
return nil
}
// otherwise, we have an actual value, so parse it
var val T
if err := json.Unmarshal(data, &val); err != nil {
return err
}
t.Set(val)
return nil
}
150 changes: 150 additions & 0 deletions sdk/nullable/nullable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package nullable

import (
"bytes"
"testing"
)

func TestTypeMarshalBool(t *testing.T) {
testCases := []struct {
value Type[bool]
expected []byte
}{
{
value: Value(true),
expected: []byte(`true`),
},
{
value: Value(false),
expected: []byte(`false`),
},
{
value: NoZero(true),
expected: []byte(`true`),
},
{
value: NoZero(false),
expected: []byte(`null`),
},
}

for i, testCase := range testCases {
result, err := testCase.value.MarshalJSON()
if err != nil {
t.Errorf("test case %d: %v", i, err)
}
if !bytes.Equal(result, testCase.expected) {
t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result))
}
}
}

func TestTypeMarshalFloat(t *testing.T) {
testCases := []struct {
value Type[float64]
expected []byte
}{
{
value: Value(123.45),
expected: []byte(`123.45`),
},
{
value: Value(0.0),
expected: []byte(`0`),
},
{
value: NoZero(123.45),
expected: []byte(`123.45`),
},
{
value: NoZero(-123.45),
expected: []byte(`-123.45`),
},
{
value: NoZero(0.0),
expected: []byte(`null`),
},
}

for i, testCase := range testCases {
result, err := testCase.value.MarshalJSON()
if err != nil {
t.Errorf("test case %d: %v", i, err)
}
if !bytes.Equal(result, testCase.expected) {
t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result))
}
}
}

func TestTypeMarshalInt(t *testing.T) {
testCases := []struct {
value Type[int]
expected []byte
}{
{
value: Value(123),
expected: []byte(`123`),
},
{
value: Value(0),
expected: []byte(`0`),
},
{
value: NoZero(123),
expected: []byte(`123`),
},
{
value: NoZero(-123),
expected: []byte(`-123`),
},
{
value: NoZero(0),
expected: []byte(`null`),
},
}

for i, testCase := range testCases {
result, err := testCase.value.MarshalJSON()
if err != nil {
t.Errorf("test case %d: %v", i, err)
}
if !bytes.Equal(result, testCase.expected) {
t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result))
}
}
}

func TestTypeMarshalString(t *testing.T) {
testCases := []struct {
value Type[string]
expected []byte
}{
{
value: Value("foo"),
expected: []byte(`"foo"`),
},
{
value: Value(""),
expected: []byte(`""`),
},
{
value: NoZero("foo"),
expected: []byte(`"foo"`),
},
{
value: NoZero(""),
expected: []byte(`null`),
},
}

for i, testCase := range testCases {
result, err := testCase.value.MarshalJSON()
if err != nil {
t.Errorf("test case %d: %v", i, err)
}
if !bytes.Equal(result, testCase.expected) {
t.Errorf("test case %d: expected %q, got %q", i, string(testCase.expected), string(result))
}
}
}
Loading