Skip to content

Commit

Permalink
[Fix] Support Query parameters for POST/PATCH operations
Browse files Browse the repository at this point in the history
  • Loading branch information
hectorcast-db committed Jan 17, 2025
1 parent 914ab6b commit 8972af3
Show file tree
Hide file tree
Showing 28 changed files with 1,619 additions and 784 deletions.
3 changes: 2 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ func (c *DatabricksClient) GetOAuthToken(ctx context.Context, authDetails string

// Do sends an HTTP request against path.
func (c *DatabricksClient) Do(ctx context.Context, method, path string,
headers map[string]string, request, response any,
headers map[string]string, queryParams map[string]any, request, response any,
visitors ...func(*http.Request) error) error {
opts := []httpclient.DoOption{}
for _, v := range visitors {
opts = append(opts, httpclient.WithRequestVisitor(v))
}
opts = append(opts, httpclient.WithQueryParameters(queryParams))
opts = append(opts, httpclient.WithRequestHeaders(headers))
opts = append(opts, httpclient.WithRequestData(request))
opts = append(opts, httpclient.WithResponseUnmarshal(response))
Expand Down
42 changes: 24 additions & 18 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ func TestSimpleRequestFailsURLError(t *testing.T) {
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/a/b", map[string]string{
"e": "f",
}, map[string]string{
"c": "d",
}, nil)
},
nil,
map[string]string{
"c": "d",
}, nil)
require.EqualError(t, err, `Get "https://some/a/b?c=d": nope`)
}

Expand All @@ -68,9 +70,11 @@ func TestSimpleRequestFailsAPIError(t *testing.T) {
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/a/b", map[string]string{
"e": "f",
}, map[string]string{
"c": "d",
}, nil)
},
nil,
map[string]string{
"c": "d",
}, nil)
require.EqualError(t, err, "nope")
require.ErrorIs(t, err, apierr.ErrInvalidParameterValue)
}
Expand Down Expand Up @@ -117,9 +121,11 @@ func TestETag(t *testing.T) {
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/a/b", map[string]string{
"e": "f",
}, map[string]string{
"c": "d",
}, nil)
},
nil,
map[string]string{
"c": "d",
}, nil)
details := apierr.GetErrorInfo(err)
require.Equal(t, 1, len(details))
errorDetails := details[0]
Expand Down Expand Up @@ -148,7 +154,7 @@ func TestSimpleRequestSucceeds(t *testing.T) {
})
require.NoError(t, err)
var resp Dummy
err = c.Do(context.Background(), "POST", "/c", nil, Dummy{1}, &resp)
err = c.Do(context.Background(), "POST", "/c", nil, nil, Dummy{1}, &resp)
require.NoError(t, err)
require.Equal(t, 2, resp.Foo)
}
Expand Down Expand Up @@ -180,7 +186,7 @@ func TestSimpleRequestRetried(t *testing.T) {
})
require.NoError(t, err)
var resp Dummy
err = c.Do(context.Background(), "PATCH", "/a", nil, Dummy{1}, &resp)
err = c.Do(context.Background(), "PATCH", "/a", nil, nil, Dummy{1}, &resp)
require.NoError(t, err)
require.Equal(t, 2, resp.Foo)
require.True(t, retried[0], "request was not retried")
Expand All @@ -203,7 +209,7 @@ func TestSimpleRequestAPIError(t *testing.T) {
}),
})
require.NoError(t, err)
err = c.Do(context.Background(), "PATCH", "/a", nil, map[string]any{}, nil)
err = c.Do(context.Background(), "PATCH", "/a", nil, nil, map[string]any{}, nil)
var aerr *apierr.APIError
require.ErrorAs(t, err, &aerr)
require.Equal(t, "NOT_FOUND", aerr.ErrorCode)
Expand All @@ -223,7 +229,7 @@ func TestHttpTransport(t *testing.T) {
client, err := New(cfg)
require.NoError(t, err)

err = client.Do(context.Background(), "GET", "/a", nil, nil, bytes.Buffer{})
err = client.Do(context.Background(), "GET", "/a", nil, nil, nil, bytes.Buffer{})
require.NoError(t, err)
require.True(t, calledMock)
}
Expand All @@ -249,9 +255,9 @@ func TestDoRemovesDoubleSlashesFromFilesAPI(t *testing.T) {
}),
})
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/api/2.0/fs/files//Volumes/abc/def/ghi", nil, map[string]any{}, nil)
err = c.Do(context.Background(), "GET", "/api/2.0/fs/files//Volumes/abc/def/ghi", nil, nil, map[string]any{}, nil)
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/api/2.0/anotherservice//test", nil, map[string]any{}, nil)
err = c.Do(context.Background(), "GET", "/api/2.0/anotherservice//test", nil, nil, map[string]any{}, nil)
require.NoError(t, err)
}

Expand Down Expand Up @@ -340,7 +346,7 @@ func captureUserAgent(t *testing.T) string {
})
require.NoError(t, err)

err = c.Do(context.Background(), "GET", "/a", nil, nil, nil)
err = c.Do(context.Background(), "GET", "/a", nil, nil, nil, nil)
require.NoError(t, err)

return userAgent
Expand Down Expand Up @@ -450,7 +456,7 @@ func testNonJSONResponseIncludedInError(t *testing.T, statusCode int, status, er
})
require.NoError(t, err)
var m map[string]string
err = c.Do(context.Background(), "GET", "/a", nil, nil, &m)
err = c.Do(context.Background(), "GET", "/a", nil, nil, nil, &m)
require.EqualError(t, err, errorMessage)
}

Expand All @@ -477,6 +483,6 @@ func TestRetryOn503(t *testing.T) {
}),
})
require.NoError(t, err)
err = c.Do(context.Background(), "GET", "/a/b", nil, map[string]any{}, nil)
err = c.Do(context.Background(), "GET", "/a/b", nil, nil, map[string]any{}, nil)
require.NoError(t, err)
}
10 changes: 9 additions & 1 deletion httpclient/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,21 @@ type DoOption struct {
body any
contentType string
isAuthOption bool
queryParams map[string]any
}

// Do sends an HTTP request against path.
func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOption) error {
var authVisitor RequestVisitor
var explicitQueryParams map[string]any
visitors := c.config.Visitors[:]
for _, o := range opts {
if o.queryParams != nil {
if explicitQueryParams != nil {
return fmt.Errorf("only one set of query params is allowed")
}
explicitQueryParams = o.queryParams
}
if o.in == nil {
continue
}
Expand Down Expand Up @@ -150,7 +158,7 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio
data = o.body
contentType = o.contentType
}
requestBody, err := makeRequestBody(method, &path, data, contentType)
requestBody, err := makeRequestBody(method, &path, data, contentType, explicitQueryParams)
if err != nil {
return fmt.Errorf("request marshal: %w", err)
}
Expand Down
17 changes: 17 additions & 0 deletions httpclient/api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ func TestSimpleRequestFailsURLError(t *testing.T) {
require.EqualError(t, err, `Get "/a/b?c=d": nope`)
}

func TestQueryParameters(t *testing.T) {
c := NewApiClient(ClientConfig{
RetryTimeout: 1 * time.Millisecond,
Transport: hc(func(r *http.Request) (*http.Response, error) {
require.Equal(t, "POST", r.Method)
require.Equal(t, "/a/b", r.URL.Path)
require.Equal(t, "c=d", r.URL.RawQuery)
return nil, fmt.Errorf("nope")
}),
})
err := c.Do(context.Background(), "POST", "/a/b",
WithQueryParameters(map[string]any{
"c": "d",
}))
require.EqualError(t, err, `Post "/a/b?c=d": nope`)
}

func TestSimpleRequestFailsAPIError(t *testing.T) {
c := NewApiClient(ClientConfig{
Transport: hc(func(r *http.Request) (*http.Response, error) {
Expand Down
51 changes: 40 additions & 11 deletions httpclient/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ func WithRequestData(body any) DoOption {
}
}

// WithQueryParameters takes a map and sends it as query string for non GET/DELETE/HEAD calls.
// This is ignored for GET/DELETE/HEAD calls, as the query parameters are serialized from the body instead.
//
// Experimental: this method may eventually be split into more granular options.
func WithQueryParameters(queryParams map[string]any) DoOption {
// refactor this, so that we split JSON/query string serialization and make
// separate request visitors internally.
return DoOption{
queryParams: queryParams,
}
}

// WithUrlEncodedData takes either a struct instance, map, string, bytes, or io.Reader plus
// a content type, and sends it either as query string for GET and DELETE calls, or as request body
// for POST, PUT, and PATCH calls. The content type is set to "application/x-www-form-urlencoded"
Expand Down Expand Up @@ -148,24 +160,41 @@ func EncodeMultiSegmentPathParameter(p string) string {
return b.String()
}

func makeRequestBody(method string, requestURL *string, data interface{}, contentType string) (common.RequestBody, error) {
if data == nil {
// We used to not send any query parameters for non GET/DELETE/HEAD requests.
// Moreover, serialization for query paramters in GET/DELETE/HEAD requests depends on the `url` tag.
// This tag is wrongly generated and fixing it will have an unknown inpact on the SDK.
// So:
// * GET/DELETE/HEAD requests are sent with query parameters serialized from data using the `url` tag as before (no change).
// * The rest of the requests are sent with query parameters serialized from explicitQueryParams, which does not use the `url` tag.
// TODO: For SDK-Mod, refactor this and remove the `url` tag completely.
func makeRequestBody(method string, requestURL *string, data interface{}, contentType string, explicitQueryParams map[string]any) (common.RequestBody, error) {
if data == nil && len(explicitQueryParams) == 0 {
return common.RequestBody{}, nil
}
if method == "GET" || method == "DELETE" || method == "HEAD" {
qs, err := makeQueryString(data)
if err != nil {
return common.RequestBody{}, err
if data != nil {
if method == "GET" || method == "DELETE" || method == "HEAD" {
qs, err := makeQueryString(data)
if err != nil {
return common.RequestBody{}, err
}
*requestURL += "?" + qs
return common.NewRequestBody([]byte{})
}
if contentType == UrlEncodedContentType {
qs, err := makeQueryString(data)
if err != nil {
return common.RequestBody{}, err
}
return common.NewRequestBody(qs)
}
*requestURL += "?" + qs
return common.NewRequestBody([]byte{})
}
if contentType == UrlEncodedContentType {
qs, err := makeQueryString(data)
if len(explicitQueryParams) > 0 {
qs, err := makeQueryString(explicitQueryParams)
if err != nil {
return common.RequestBody{}, err
}
return common.NewRequestBody(qs)
*requestURL += "?" + qs
return common.NewRequestBody(data)
}
return common.NewRequestBody(data)
}
41 changes: 32 additions & 9 deletions httpclient/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ func TestMakeRequestBody(t *testing.T) {
Scope string `json:"scope" url:"scope"`
}
requestURL := "/a/b/c"
body, err := makeRequestBody("GET", &requestURL, x{"test"}, "")
body, err := makeRequestBody("GET", &requestURL, x{"test"}, "", nil)
require.NoError(t, err)
bodyBytes, err := io.ReadAll(body.Reader)
require.NoError(t, err)
require.Equal(t, "/a/b/c?scope=test", requestURL)
require.Equal(t, 0, len(bodyBytes))

requestURL = "/a/b/c"
body, err = makeRequestBody("POST", &requestURL, x{"test"}, "")
body, err = makeRequestBody("POST", &requestURL, x{"test"}, "", nil)
require.NoError(t, err)
bodyBytes, err = io.ReadAll(body.Reader)
require.NoError(t, err)
Expand All @@ -37,7 +37,7 @@ func TestMakeRequestBody(t *testing.T) {
require.Equal(t, []byte(x1), bodyBytes)

requestURL = "/a/b/c"
body, err = makeRequestBody("HEAD", &requestURL, x{"test"}, "")
body, err = makeRequestBody("HEAD", &requestURL, x{"test"}, "", nil)
require.NoError(t, err)
bodyBytes, err = io.ReadAll(body.Reader)
require.NoError(t, err)
Expand All @@ -47,7 +47,7 @@ func TestMakeRequestBody(t *testing.T) {

func TestMakeRequestBodyFromReader(t *testing.T) {
requestURL := "/a/b/c"
body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc"), "")
body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc"), "", nil)
require.NoError(t, err)
bodyBytes, err := io.ReadAll(body.Reader)
require.NoError(t, err)
Expand All @@ -61,7 +61,7 @@ func TestUrlEncoding(t *testing.T) {
GrantType: "grant",
}
requestURL := "/a/b/c"
body, err := makeRequestBody("POST", &requestURL, data, UrlEncodedContentType)
body, err := makeRequestBody("POST", &requestURL, data, UrlEncodedContentType, nil)
require.NoError(t, err)
bodyBytes, err := io.ReadAll(body.Reader)
require.NoError(t, err)
Expand All @@ -71,7 +71,7 @@ func TestUrlEncoding(t *testing.T) {

func TestMakeRequestBodyReaderError(t *testing.T) {
requestURL := "/a/b/c"
_, err := makeRequestBody("POST", &requestURL, errReader(false), "")
_, err := makeRequestBody("POST", &requestURL, errReader(false), "", nil)
// The request body is only read once the request is sent, so no error
// should be returned until then.
require.NoError(t, err, "request body reader error should be ignored")
Expand All @@ -82,7 +82,7 @@ func TestMakeRequestBodyJsonError(t *testing.T) {
type x struct {
Foo chan string `json:"foo"`
}
_, err := makeRequestBody("POST", &requestURL, x{make(chan string)}, "")
_, err := makeRequestBody("POST", &requestURL, x{make(chan string)}, "", nil)
require.EqualError(t, err, "request marshal failure: json: unsupported type: chan string")
}

Expand All @@ -97,13 +97,13 @@ func TestMakeRequestBodyQueryFailingEncode(t *testing.T) {
type x struct {
Foo failingUrlEncode `url:"foo"`
}
_, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}, "")
_, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}, "", nil)
require.EqualError(t, err, "cannot create query string: always failing")
}

func TestMakeRequestBodyQueryUnsupported(t *testing.T) {
requestURL := "/a/b/c"
_, err := makeRequestBody("GET", &requestURL, true, "")
_, err := makeRequestBody("GET", &requestURL, true, "", nil)
require.EqualError(t, err, "unsupported query string data: true")
}

Expand Down Expand Up @@ -141,3 +141,26 @@ func TestEncodeMultiSegmentPathParameter(t *testing.T) {
// # and ? should be encoded.
assert.Equal(t, "a%23b%3Fc", EncodeMultiSegmentPathParameter("a#b?c"))
}

func TestMakeRequestBodyExplicitQueryParams(t *testing.T) {
type x struct {
Scope string `json:"scope" url:"scope"`
}
requestURL := "/a/b/c"
// For GET, it should be ignored.
body, err := makeRequestBody("GET", &requestURL, x{"test"}, "", map[string]any{"foo": "bar"})
require.NoError(t, err)
bodyBytes, err := io.ReadAll(body.Reader)
require.NoError(t, err)
require.Equal(t, "/a/b/c?scope=test", requestURL)
require.Equal(t, 0, len(bodyBytes))

requestURL = "/a/b/c"
body, err = makeRequestBody("POST", &requestURL, x{"test"}, "", map[string]any{"foo": "bar"})
require.NoError(t, err)
bodyBytes, err = io.ReadAll(body.Reader)
require.NoError(t, err)
require.Equal(t, "/a/b/c?foo=bar", requestURL)
x1 := `{"scope":"test"}`
require.Equal(t, []byte(x1), bodyBytes)
}
Loading

0 comments on commit 8972af3

Please sign in to comment.