]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: add NewRequestWithContext
authorKevin Burke <kevin@burke.dev>
Thu, 7 Dec 2023 21:13:25 +0000 (13:13 -0800)
committerDamien Neil <dneil@google.com>
Mon, 11 Mar 2024 18:09:14 +0000 (18:09 +0000)
This matches the net/http API.

Updates #59473.

Change-Id: I99917cef3ed42a0b4a2b39230b492be00da8bbfd
Reviewed-on: https://go-review.googlesource.com/c/go/+/548355
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
api/next/59473.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/http/httptest/59473.md [new file with mode: 0644]
src/net/http/httptest/httptest.go
src/net/http/httptest/httptest_test.go

diff --git a/api/next/59473.txt b/api/next/59473.txt
new file mode 100644 (file)
index 0000000..da6902d
--- /dev/null
@@ -0,0 +1 @@
+pkg net/http/httptest, func NewRequestWithContext(context.Context, string, string, io.Reader) *http.Request #59473
diff --git a/doc/next/6-stdlib/99-minor/net/http/httptest/59473.md b/doc/next/6-stdlib/99-minor/net/http/httptest/59473.md
new file mode 100644 (file)
index 0000000..65cc607
--- /dev/null
@@ -0,0 +1,2 @@
+The new NewRequestWithContext method creates an incoming request with
+a Context.
index f0ca64362d7fbd8e25719b64132c423c7313bda2..0c0dbb40e89bc567a6b4bcac34f7622c4339284d 100644 (file)
@@ -8,13 +8,19 @@ package httptest
 import (
        "bufio"
        "bytes"
+       "context"
        "crypto/tls"
        "io"
        "net/http"
        "strings"
 )
 
-// NewRequest returns a new incoming server Request, suitable
+// NewRequest wraps NewRequestWithContext using context.Background.
+func NewRequest(method, target string, body io.Reader) *http.Request {
+       return NewRequestWithContext(context.Background(), method, target, body)
+}
+
+// NewRequestWithContext returns a new incoming server Request, suitable
 // for passing to an [http.Handler] for testing.
 //
 // The target is the RFC 7230 "request-target": it may be either a
@@ -37,7 +43,7 @@ import (
 //
 // To generate a client HTTP request instead of a server request, see
 // the NewRequest function in the net/http package.
-func NewRequest(method, target string, body io.Reader) *http.Request {
+func NewRequestWithContext(ctx context.Context, method, target string, body io.Reader) *http.Request {
        if method == "" {
                method = "GET"
        }
@@ -45,6 +51,7 @@ func NewRequest(method, target string, body io.Reader) *http.Request {
        if err != nil {
                panic("invalid NewRequest arguments; " + err.Error())
        }
+       req = req.WithContext(ctx)
 
        // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here.
        req.Proto = "HTTP/1.1"
index 071add67ea6dbcf7a357a9a1ff63517187a31ee3..d5a4c3dc9d19a10b2a368f96f25d76f532a405f4 100644 (file)
@@ -5,6 +5,7 @@
 package httptest
 
 import (
+       "context"
        "crypto/tls"
        "io"
        "net/http"
@@ -15,6 +16,26 @@ import (
 )
 
 func TestNewRequest(t *testing.T) {
+       got := NewRequest("GET", "/", nil)
+       want := &http.Request{
+               Method:     "GET",
+               Host:       "example.com",
+               URL:        &url.URL{Path: "/"},
+               Header:     http.Header{},
+               Proto:      "HTTP/1.1",
+               ProtoMajor: 1,
+               ProtoMinor: 1,
+               RemoteAddr: "192.0.2.1:1234",
+               RequestURI: "/",
+       }
+       got.Body = nil // before DeepEqual
+       want = want.WithContext(context.Background())
+       if !reflect.DeepEqual(got, want) {
+               t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, want)
+       }
+}
+
+func TestNewRequestWithContext(t *testing.T) {
        for _, tt := range [...]struct {
                name string
 
@@ -153,7 +174,7 @@ func TestNewRequest(t *testing.T) {
                },
        } {
                t.Run(tt.name, func(t *testing.T) {
-                       got := NewRequest(tt.method, tt.uri, tt.body)
+                       got := NewRequestWithContext(context.Background(), tt.method, tt.uri, tt.body)
                        slurp, err := io.ReadAll(got.Body)
                        if err != nil {
                                t.Errorf("ReadAll: %v", err)
@@ -161,6 +182,7 @@ func TestNewRequest(t *testing.T) {
                        if string(slurp) != tt.wantBody {
                                t.Errorf("Body = %q; want %q", slurp, tt.wantBody)
                        }
+                       tt.want = tt.want.WithContext(context.Background())
                        got.Body = nil // before DeepEqual
                        if !reflect.DeepEqual(got.URL, tt.want.URL) {
                                t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL)