From ca4fe47c8ae9dbccae7fa3ae353148a27f659698 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 29 Nov 2022 13:45:43 +0100 Subject: [PATCH] Fix data race in Content-Type header (#9) --- .github/workflows/cloc.yml | 4 ++- .github/workflows/golangci-lint.yml | 6 ++-- .github/workflows/gorelease.yml | 2 +- .github/workflows/test-unit.yml | 6 ++-- .golangci.yml | 14 ++++++++ Makefile | 2 +- client.go | 52 +++++++++++++++++------------ client_test.go | 8 ++--- go.mod | 4 +-- go.sum | 7 ++-- server.go | 6 ++-- server_test.go | 21 ++++++------ 12 files changed, 80 insertions(+), 52 deletions(-) diff --git a/.github/workflows/cloc.yml b/.github/workflows/cloc.yml index 4592bb4..619ca74 100644 --- a/.github/workflows/cloc.yml +++ b/.github/workflows/cloc.yml @@ -24,7 +24,9 @@ jobs: - name: Count Lines Of Code id: loc run: | - curl -sLO https://github.com/vearutop/sccdiff/releases/download/v1.0.2/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz && echo "b17e76bede22af0206b4918d3b3c4e7357f2a21b57f8de9e7c9dc0eb56b676c0 sccdiff" | shasum -c + curl -sLO https://github.com/vearutop/sccdiff/releases/download/v1.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + sccdiff_hash=$(git hash-object ./sccdiff) + [ "$sccdiff_hash" == "ae8a07b687bd3dba60861584efe724351aa7ff63" ] || (echo "::error::unexpected hash for sccdiff, possible tampering: $sccdiff_hash" && exit 1) OUTPUT=$(cd pr && ../sccdiff -basedir ../base) echo "${OUTPUT}" OUTPUT="${OUTPUT//$'\n'/%0A}" diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 48207f9..bf0bcdb 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -21,13 +21,13 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.18.x + go-version: 1.19.x - uses: actions/checkout@v2 - name: golangci-lint - uses: golangci/golangci-lint-action@v3.1.0 + uses: golangci/golangci-lint-action@v3.2.0 with: # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. - version: v1.46.2 + version: v1.50.0 # Optional: working directory, useful for monorepos # working-directory: somedir diff --git a/.github/workflows/gorelease.yml b/.github/workflows/gorelease.yml index f1c678a..6267500 100644 --- a/.github/workflows/gorelease.yml +++ b/.github/workflows/gorelease.yml @@ -9,7 +9,7 @@ concurrency: cancel-in-progress: true env: - GO_VERSION: 1.18.x + GO_VERSION: 1.19.x jobs: gorelease: runs-on: ubuntu-latest diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index 60cdcda..94441bd 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -21,7 +21,7 @@ jobs: test: strategy: matrix: - go-version: [ 1.16.x, 1.17.x, 1.18.x ] + go-version: [ 1.16.x, 1.17.x, 1.18.x, 1.19.x ] runs-on: ubuntu-latest steps: - name: Install Go stable @@ -88,8 +88,10 @@ jobs: id: annotate if: matrix.go-version == env.COV_GO_VERSION && github.event.pull_request.base.sha != '' run: | + curl -sLO https://github.com/vearutop/gocovdiff/releases/download/v1.3.6/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + gocovdiff_hash=$(git hash-object ./gocovdiff) + [ "$gocovdiff_hash" == "8e507e0d671d4d6dfb3612309b72b163492f28eb" ] || (echo "::error::unexpected hash for gocovdiff, possible tampering: $gocovdiff_hash" && exit 1) git fetch origin master ${{ github.event.pull_request.base.sha }} - curl -sLO https://github.com/vearutop/gocovdiff/releases/download/v1.3.4/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz && shasum -a 256 gocovdiff && echo "b351c67526eefeb0671c82e9271ae984875865eed19e911f40f78348cb98347c gocovdiff" | shasum -c REP=$(./gocovdiff -cov unit.coverprofile -gha-annotations gha-unit.txt -delta-cov-file delta-cov-unit.txt -target-delta-cov ${TARGET_DELTA_COV}) echo "${REP}" REP="${REP//$'\n'/%0A}" diff --git a/.golangci.yml b/.golangci.yml index 591e45f..bc4a080 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -42,6 +42,12 @@ linters: - ireturn - exhaustruct - nonamedreturns + - nosnakecase + - structcheck + - varcheck + - deadcode + - testableexamples + - dupword issues: exclude-use-default: false @@ -53,5 +59,13 @@ issues: - noctx - funlen - dupl + - structcheck + - unused + - unparam + - nosnakecase path: "_test.go" + - linters: + - errcheck # Error checking omitted for brevity. + - gosec + path: "example_" diff --git a/Makefile b/Makefile index 7de96d0..e578775 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -#GOLANGCI_LINT_VERSION := "v1.46.2" # Optional configuration to pinpoint golangci-lint version. +#GOLANGCI_LINT_VERSION := "v1.50.0" # Optional configuration to pinpoint golangci-lint version. # The head of Makefile determines location of dev-go to include standard targets. GO ?= go diff --git a/client.go b/client.go index f606e2a..55be17e 100644 --- a/client.go +++ b/client.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/cookiejar" "net/url" @@ -34,7 +33,7 @@ type Client struct { // Cookies are default cookies added to all requests, can be overridden by WithCookie. Cookies map[string]string - ctx context.Context // nolint:containedctx // Context is configured separately. + ctx context.Context //nolint:containedctx // Context is configured separately. resp *http.Response respBody []byte @@ -231,7 +230,7 @@ func (c *Client) do() (err error) { return } - body, er := ioutil.ReadAll(resp.Body) + body, er := io.ReadAll(resp.Body) if er != nil { return } @@ -264,8 +263,8 @@ func (c *Client) do() (err error) { // CheckResponses checks if responses qualify idempotence criteria. // // Operation is considered idempotent in one of two cases: -// * all responses have same status code (e.g. GET /resource: all 200 OK), -// * all responses but one have same status code (e.g. POST /resource: one 200 OK, many 409 Conflict). +// - all responses have same status code (e.g. GET /resource: all 200 OK), +// - all responses but one have same status code (e.g. POST /resource: one 200 OK, many 409 Conflict). // // Any other case is considered an idempotence violation. func (c *Client) checkResponses( @@ -348,27 +347,13 @@ func (c *Client) buildBody() io.Reader { c.reqMethod = http.MethodPost } - c.reqHeaders["Content-Type"] = "application/x-www-form-urlencoded" - return strings.NewReader(c.reqFormDataParams.Encode()) } return nil } -func (c *Client) doOnce() (*http.Response, error) { - uri, err := c.buildURI() - if err != nil { - return nil, err - } - - body := c.buildBody() - - req, err := http.NewRequestWithContext(c.ctx, c.reqMethod, uri, body) - if err != nil { - return nil, err - } - +func (c *Client) applyHeaders(req *http.Request) { for k, v := range c.Headers { req.Header.Set(k, v) } @@ -377,6 +362,12 @@ func (c *Client) doOnce() (*http.Response, error) { req.Header.Set(k, v) } + if len(c.reqFormDataParams) > 0 && req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } +} + +func (c *Client) applyCookies(req *http.Request) { cookies := make([]http.Cookie, 0, len(c.Cookies)+len(c.reqCookies)) for n, v := range c.Cookies { @@ -399,6 +390,23 @@ func (c *Client) doOnce() (*http.Response, error) { v := v req.AddCookie(&v) } +} + +func (c *Client) doOnce() (*http.Response, error) { + uri, err := c.buildURI() + if err != nil { + return nil, err + } + + body := c.buildBody() + + req, err := http.NewRequestWithContext(c.ctx, c.reqMethod, uri, body) + if err != nil { + return nil, err + } + + c.applyHeaders(req) + c.applyCookies(req) tr := c.Transport if tr == nil { @@ -407,7 +415,7 @@ func (c *Client) doOnce() (*http.Response, error) { if c.followRedirects { cl := http.Client{} - j, _ := cookiejar.New(nil) // nolint:errcheck // Error is always nil. + j, _ := cookiejar.New(nil) //nolint:errcheck // Error is always nil. cl.Transport = tr cl.Jar = j @@ -604,7 +612,7 @@ func (c *Client) checkBody(expected, received []byte) (err error) { } if !bytes.Equal(expected, received) { - return fmt.Errorf("%w, expected: %s, received: %s", + return fmt.Errorf("%w, expected: %q, received: %q", errUnexpectedBody, string(expected), string(received)) } diff --git a/client_test.go b/client_test.go index f91d580..c589f49 100644 --- a/client_test.go +++ b/client_test.go @@ -3,7 +3,7 @@ package httpmock_test import ( "context" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "sync/atomic" @@ -18,7 +18,7 @@ func TestNewClient(t *testing.T) { cnt := int64(0) srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(t, "/foo?q=1", r.URL.String()) - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) assert.NoError(t, err) assert.Equal(t, `{"foo":"bar"}`, string(b)) assert.Equal(t, "application/json", r.Header.Get("Content-Type")) @@ -104,7 +104,7 @@ func TestNewClient_failedExpectation(t *testing.T) { c.WithURI("/") assert.EqualError(t, c.ExpectResponseBody([]byte(`{"foo":"bar}"`)), - "unexpected body, expected: {\"foo\":\"bar}\", received: {\"bar\":\"foo\"}") + "unexpected body, expected: \"{\\\"foo\\\":\\\"bar}\\\"\", received: \"{\\\"bar\\\":\\\"foo\\\"}\"") } func TestNewClient_followRedirects(t *testing.T) { @@ -180,5 +180,5 @@ func TestNewClient_formData(t *testing.T) { c.WithURLEncodedFormDataParam("qux", "quux") assert.EqualError(t, c.ExpectResponseBody([]byte(`{"foo":"bar}"`)), - "unexpected body, expected: {\"foo\":\"bar}\", received: {\"bar\":\"foo\"}") + "unexpected body, expected: \"{\\\"foo\\\":\\\"bar}\\\"\", received: \"{\\\"bar\\\":\\\"foo\\\"}\"") } diff --git a/go.mod b/go.mod index 2432f4e..ec3252a 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/bool64/httpmock go 1.17 require ( - github.com/bool64/dev v0.2.17 + github.com/bool64/dev v0.2.22 github.com/bool64/shared v0.1.5 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.8.1 github.com/swaggest/assertjson v1.7.0 ) diff --git a/go.sum b/go.sum index 91e0167..e25cff3 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,8 @@ github.com/bool64/dev v0.2.5/go.mod h1:cTHiTDNc8EewrQPy3p1obNilpMpdmlUesDkFTF2zRWU= github.com/bool64/dev v0.2.10/go.mod h1:/csLrm+4oDSsKJRIVS0mrywAonLnYKFG8RvGT7Jh9b8= -github.com/bool64/dev v0.2.17 h1:jE+T92oazAIV8fvMDJrKjsF1bzfr5XezZ8bM5GS1Cl0= github.com/bool64/dev v0.2.17/go.mod h1:iJbh1y/HkunEPhgebWRNcs8wfGq7sjvJ6W5iabL8ACg= +github.com/bool64/dev v0.2.22 h1:YJFKBRKplkt+0Emq/5Xk1Z5QRmMNzc1UOJkR3rxJksA= +github.com/bool64/dev v0.2.22/go.mod h1:iJbh1y/HkunEPhgebWRNcs8wfGq7sjvJ6W5iabL8ACg= github.com/bool64/shared v0.1.4/go.mod h1:ryGjsnQFh6BnEXClfVlEJrzjwzat7CmA8PNS5E+jPp0= github.com/bool64/shared v0.1.5 h1:fp3eUhBsrSjNCQPcSdQqZxxh9bBwrYiZ+zOKFkM0/2E= github.com/bool64/shared v0.1.5/go.mod h1:081yz68YC9jeFB3+Bbmno2RFWvGKv1lPKkMP6MHJlPs= @@ -53,10 +54,12 @@ github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/swaggest/assertjson v1.7.0 h1:SKw5Rn0LQs6UvmGrIdaKQbMR1R3ncXm5KNon+QJ7jtw= github.com/swaggest/assertjson v1.7.0/go.mod h1:vxMJMehbSVJd+dDWFCKv3QRZKNTpy/ktZKTz9LOEDng= github.com/yosuke-furukawa/json5 v0.1.2-0.20201207051438-cf7bb3f354ff h1:7YqG491bE4vstXRz1lD38rbSgbXnirvROz1lZiOnPO8= diff --git a/server.go b/server.go index b46119d..1be39a9 100755 --- a/server.go +++ b/server.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "sync" @@ -191,7 +191,7 @@ func (sm *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if len(sm.expectations) == 0 { - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err == nil && len(body) > 0 { if sm.OnBodyMismatch != nil { sm.OnBodyMismatch(body) @@ -285,7 +285,7 @@ func (sm *Server) checkRequest(req *http.Request, expectation Expectation) error } } - reqBody, err := ioutil.ReadAll(req.Body) + reqBody, err := io.ReadAll(req.Body) if err != nil { return err } diff --git a/server_test.go b/server_test.go index fbf38cb..b590e1c 100644 --- a/server_test.go +++ b/server_test.go @@ -3,7 +3,6 @@ package httpmock_test import ( "bytes" "io" - "io/ioutil" "net/http" "strings" "sync" @@ -38,7 +37,7 @@ func assertRoundTrip(t *testing.T, baseURL string, expectation httpmock.Expectat resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -161,7 +160,7 @@ func TestServer_ServeHTTP_error(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - respBody, err := ioutil.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -176,7 +175,7 @@ func TestServer_ServeHTTP_error(t *testing.T) { resp, err = http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - respBody, err = ioutil.ReadAll(resp.Body) + respBody, err = io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -191,7 +190,7 @@ func TestServer_ServeHTTP_error(t *testing.T) { resp, err = http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - respBody, err = ioutil.ReadAll(resp.Body) + respBody, err = io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -206,7 +205,7 @@ func TestServer_ServeHTTP_error(t *testing.T) { resp, err = http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - respBody, err = ioutil.ReadAll(resp.Body) + respBody, err = io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -250,7 +249,7 @@ func TestServer_ServeHTTP_concurrency(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - respBody, err := ioutil.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) require.NoError(t, resp.Body.Close()) require.NoError(t, err) @@ -301,7 +300,7 @@ func TestServer_vars(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -340,7 +339,7 @@ func TestServer_ExpectAsync(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -357,7 +356,7 @@ func TestServer_ExpectAsync(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) @@ -371,7 +370,7 @@ func TestServer_ExpectAsync(t *testing.T) { resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close())