]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add CrossOriginProtection
authorFilippo Valsorda <filippo@golang.org>
Wed, 21 May 2025 13:35:51 +0000 (15:35 +0200)
committerGopher Robot <gobot@golang.org>
Wed, 21 May 2025 20:22:27 +0000 (13:22 -0700)
Fixes #73626

Change-Id: I6a6a4656862e7a38acb65c4815fb7a1e04896172
Reviewed-on: https://go-review.googlesource.com/c/go/+/674936
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
api/next/73626.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/http/73626.md [new file with mode: 0644]
src/net/http/csrf.go [new file with mode: 0644]
src/net/http/csrf_test.go [new file with mode: 0644]

diff --git a/api/next/73626.txt b/api/next/73626.txt
new file mode 100644 (file)
index 0000000..ef4d068
--- /dev/null
@@ -0,0 +1,7 @@
+pkg net/http, func NewCrossOriginProtection() *CrossOriginProtection #73626
+pkg net/http, method (*CrossOriginProtection) AddInsecureBypassPattern(string) #73626
+pkg net/http, method (*CrossOriginProtection) AddTrustedOrigin(string) error #73626
+pkg net/http, method (*CrossOriginProtection) Check(*Request) error #73626
+pkg net/http, method (*CrossOriginProtection) Handler(Handler) Handler #73626
+pkg net/http, method (*CrossOriginProtection) SetDenyHandler(Handler) #73626
+pkg net/http, type CrossOriginProtection struct #73626
diff --git a/doc/next/6-stdlib/99-minor/net/http/73626.md b/doc/next/6-stdlib/99-minor/net/http/73626.md
new file mode 100644 (file)
index 0000000..88a204b
--- /dev/null
@@ -0,0 +1,7 @@
+The new [CrossOriginProtection] implements protections against [Cross-Site
+Request Forgery (CSRF)][] by rejecting non-safe cross-origin browser requests.
+It uses [modern browser Fetch metadata][Sec-Fetch-Site], doesn't require tokens
+or cookies, and supports origin-based and pattern-based bypasses.
+
+[Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
+[Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
diff --git a/src/net/http/csrf.go b/src/net/http/csrf.go
new file mode 100644 (file)
index 0000000..a46071f
--- /dev/null
@@ -0,0 +1,182 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "errors"
+       "fmt"
+       "net/url"
+       "sync"
+       "sync/atomic"
+)
+
+// CrossOriginProtection implements protections against [Cross-Site Request
+// Forgery (CSRF)] by rejecting non-safe cross-origin browser requests.
+//
+// Cross-origin requests are currently detected with the [Sec-Fetch-Site]
+// header, available in all browsers since 2023, or by comparing the hostname of
+// the [Origin] header with the Host header.
+//
+// The GET, HEAD, and OPTIONS methods are [safe methods] and are always allowed.
+// It's important that applications do not perform any state changing actions
+// due to requests with safe methods.
+//
+// Requests without Sec-Fetch-Site or Origin headers are currently assumed to be
+// either same-origin or non-browser requests, and are allowed.
+//
+// [Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
+// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+// [Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
+// [safe methods]: https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
+type CrossOriginProtection struct {
+       bypass    *ServeMux
+       trustedMu sync.RWMutex
+       trusted   map[string]bool
+       deny      atomic.Pointer[Handler]
+}
+
+// NewCrossOriginProtection returns a new [CrossOriginProtection] value.
+func NewCrossOriginProtection() *CrossOriginProtection {
+       return &CrossOriginProtection{
+               bypass:  NewServeMux(),
+               trusted: make(map[string]bool),
+       }
+}
+
+// AddTrustedOrigin allows all requests with an [Origin] header
+// which exactly matches the given value.
+//
+// Origin header values are of the form "scheme://host[:port]".
+//
+// AddTrustedOrigin can be called concurrently with other methods
+// or request handling, and applies to future requests.
+//
+// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error {
+       u, err := url.Parse(origin)
+       if err != nil {
+               return fmt.Errorf("invalid origin %q: %w", origin, err)
+       }
+       if u.Scheme == "" {
+               return fmt.Errorf("invalid origin %q: scheme is required", origin)
+       }
+       if u.Host == "" {
+               return fmt.Errorf("invalid origin %q: host is required", origin)
+       }
+       if u.Path != "" || u.RawQuery != "" || u.Fragment != "" {
+               return fmt.Errorf("invalid origin %q: path, query, and fragment are not allowed", origin)
+       }
+       c.trustedMu.Lock()
+       defer c.trustedMu.Unlock()
+       c.trusted[origin] = true
+       return nil
+}
+
+var noopHandler = HandlerFunc(func(w ResponseWriter, r *Request) {})
+
+// AddInsecureBypassPattern permits all requests that match the given pattern.
+// The pattern syntax and precedence rules are the same as [ServeMux].
+//
+// AddInsecureBypassPattern can be called concurrently with other methods
+// or request handling, and applies to future requests.
+func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) {
+       c.bypass.Handle(pattern, noopHandler)
+}
+
+// SetDenyHandler sets a handler to invoke when a request is rejected.
+// The default error handler responds with a 403 Forbidden status.
+//
+// SetDenyHandler can be called concurrently with other methods
+// or request handling, and applies to future requests.
+//
+// Check does not call the error handler.
+func (c *CrossOriginProtection) SetDenyHandler(h Handler) {
+       if h == nil {
+               c.deny.Store(nil)
+               return
+       }
+       c.deny.Store(&h)
+}
+
+// Check applies cross-origin checks to a request.
+// It returns an error if the request should be rejected.
+func (c *CrossOriginProtection) Check(req *Request) error {
+       switch req.Method {
+       case "GET", "HEAD", "OPTIONS":
+               // Safe methods are always allowed.
+               return nil
+       }
+
+       switch req.Header.Get("Sec-Fetch-Site") {
+       case "":
+               // No Sec-Fetch-Site header is present.
+               // Fallthrough to check the Origin header.
+       case "same-origin", "none":
+               return nil
+       default:
+               if c.isRequestExempt(req) {
+                       return nil
+               }
+               return errors.New("cross-origin request detected from Sec-Fetch-Site header")
+       }
+
+       origin := req.Header.Get("Origin")
+       if origin == "" {
+               // Neither Sec-Fetch-Site nor Origin headers are present.
+               // Either the request is same-origin or not a browser request.
+               return nil
+       }
+
+       if o, err := url.Parse(origin); err == nil && o.Host == req.Host {
+               // The Origin header matches the Host header. Note that the Host header
+               // doesn't include the scheme, so we don't know if this might be an
+               // HTTP→HTTPS cross-origin request. We fail open, since all modern
+               // browsers support Sec-Fetch-Site since 2023, and running an older
+               // browser makes a clear security trade-off already. Sites can mitigate
+               // this with HTTP Strict Transport Security (HSTS).
+               return nil
+       }
+
+       if c.isRequestExempt(req) {
+               return nil
+       }
+       return errors.New("cross-origin request detected, and/or browser is out of date: " +
+               "Sec-Fetch-Site is missing, and Origin does not match Host")
+}
+
+// isRequestExempt checks the bypasses which require taking a lock, and should
+// be deferred until the last moment.
+func (c *CrossOriginProtection) isRequestExempt(req *Request) bool {
+       if _, pattern := c.bypass.Handler(req); pattern != "" {
+               // The request matches a bypass pattern.
+               return true
+       }
+
+       c.trustedMu.RLock()
+       defer c.trustedMu.RUnlock()
+       origin := req.Header.Get("Origin")
+       // The request matches a trusted origin.
+       return origin != "" && c.trusted[origin]
+}
+
+// Handler returns a handler that applies cross-origin checks
+// before invoking the handler h.
+//
+// If a request fails cross-origin checks, the request is rejected
+// with a 403 Forbidden status or handled with the handler passed
+// to [CrossOriginProtection.SetDenyHandler].
+func (c *CrossOriginProtection) Handler(h Handler) Handler {
+       return HandlerFunc(func(w ResponseWriter, r *Request) {
+               if err := c.Check(r); err != nil {
+                       if deny := c.deny.Load(); deny != nil {
+                               (*deny).ServeHTTP(w, r)
+                               return
+                       }
+                       Error(w, err.Error(), StatusForbidden)
+                       return
+               }
+               h.ServeHTTP(w, r)
+       })
+}
diff --git a/src/net/http/csrf_test.go b/src/net/http/csrf_test.go
new file mode 100644 (file)
index 0000000..30986a4
--- /dev/null
@@ -0,0 +1,330 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "io"
+       "net/http"
+       "net/http/httptest"
+       "strings"
+       "testing"
+)
+
+// httptestNewRequest works around https://go.dev/issue/73151.
+func httptestNewRequest(method, target string) *http.Request {
+       req := httptest.NewRequest(method, target, nil)
+       req.URL.Scheme = ""
+       req.URL.Host = ""
+       return req
+}
+
+var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+       w.WriteHeader(http.StatusOK)
+})
+
+func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+       handler := protection.Handler(okHandler)
+
+       tests := []struct {
+               name           string
+               method         string
+               secFetchSite   string
+               origin         string
+               expectedStatus int
+       }{
+               {"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
+               {"none allowed", "POST", "none", "", http.StatusOK},
+               {"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
+               {"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
+
+               {"no header with no origin", "POST", "", "", http.StatusOK},
+               {"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
+               {"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
+               {"no header with null origin", "POST", "", "null", http.StatusForbidden},
+
+               {"GET allowed", "GET", "cross-site", "", http.StatusOK},
+               {"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
+               {"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
+               {"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       req := httptestNewRequest(tc.method, "https://example.com/")
+                       if tc.secFetchSite != "" {
+                               req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
+                       }
+                       if tc.origin != "" {
+                               req.Header.Set("Origin", tc.origin)
+                       }
+
+                       w := httptest.NewRecorder()
+                       handler.ServeHTTP(w, req)
+
+                       if w.Code != tc.expectedStatus {
+                               t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
+                       }
+               })
+       }
+}
+
+func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+       err := protection.AddTrustedOrigin("https://trusted.example")
+       if err != nil {
+               t.Fatalf("AddTrustedOrigin: %v", err)
+       }
+       handler := protection.Handler(okHandler)
+
+       tests := []struct {
+               name           string
+               origin         string
+               secFetchSite   string
+               expectedStatus int
+       }{
+               {"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
+               {"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
+               {"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
+               {"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       req := httptestNewRequest("POST", "https://example.com/")
+                       req.Header.Set("Origin", tc.origin)
+                       if tc.secFetchSite != "" {
+                               req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
+                       }
+
+                       w := httptest.NewRecorder()
+                       handler.ServeHTTP(w, req)
+
+                       if w.Code != tc.expectedStatus {
+                               t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
+                       }
+               })
+       }
+}
+
+func TestCrossOriginProtectionPatternBypass(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+       protection.AddInsecureBypassPattern("/bypass/")
+       protection.AddInsecureBypassPattern("/only/{foo}")
+       handler := protection.Handler(okHandler)
+
+       tests := []struct {
+               name           string
+               path           string
+               secFetchSite   string
+               expectedStatus int
+       }{
+               {"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
+               {"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
+               {"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
+               {"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
+
+               {"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusOK},
+               {"redirect to bypass path with trailing slash", "/bypass", "", http.StatusOK},
+               {"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
+               {"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
+
+               {"wildcard bypass", "/only/123", "", http.StatusOK},
+               {"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       req := httptestNewRequest("POST", "https://example.com"+tc.path)
+                       req.Header.Set("Origin", "https://attacker.example")
+                       if tc.secFetchSite != "" {
+                               req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
+                       }
+
+                       w := httptest.NewRecorder()
+                       handler.ServeHTTP(w, req)
+
+                       if w.Code != tc.expectedStatus {
+                               t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus)
+                       }
+               })
+       }
+}
+
+func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+
+       handler := protection.Handler(okHandler)
+
+       req := httptestNewRequest("POST", "https://example.com/")
+       req.Header.Set("Sec-Fetch-Site", "cross-site")
+
+       w := httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusForbidden {
+               t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
+       }
+
+       customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               w.WriteHeader(http.StatusTeapot)
+               io.WriteString(w, "custom error")
+       })
+       protection.SetDenyHandler(customErrHandler)
+
+       w = httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusTeapot {
+               t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot)
+       }
+
+       if !strings.Contains(w.Body.String(), "custom error") {
+               t.Errorf("expected custom error message, got: %q", w.Body.String())
+       }
+
+       req = httptestNewRequest("GET", "https://example.com/")
+
+       w = httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusOK {
+               t.Errorf("got status %d, want %d", w.Code, http.StatusOK)
+       }
+
+       protection.SetDenyHandler(nil)
+
+       req = httptestNewRequest("POST", "https://example.com/")
+       req.Header.Set("Sec-Fetch-Site", "cross-site")
+
+       w = httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusForbidden {
+               t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
+       }
+}
+
+func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+
+       tests := []struct {
+               name    string
+               origin  string
+               wantErr bool
+       }{
+               {"valid origin", "https://example.com", false},
+               {"valid origin with port", "https://example.com:8080", false},
+               {"http origin", "http://example.com", false},
+               {"missing scheme", "example.com", true},
+               {"missing host", "https://", true},
+               {"trailing slash", "https://example.com/", true},
+               {"with path", "https://example.com/path", true},
+               {"with query", "https://example.com?query=value", true},
+               {"with fragment", "https://example.com#fragment", true},
+               {"invalid url", "https://ex ample.com", true},
+               {"empty string", "", true},
+               {"null", "null", true},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       err := protection.AddTrustedOrigin(tc.origin)
+                       if (err != nil) != tc.wantErr {
+                               t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr)
+                       }
+               })
+       }
+}
+
+func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+       handler := protection.Handler(okHandler)
+
+       req := httptestNewRequest("POST", "https://example.com/")
+       req.Header.Set("Origin", "https://concurrent.example")
+       req.Header.Set("Sec-Fetch-Site", "cross-site")
+
+       w := httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusForbidden {
+               t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
+       }
+
+       start := make(chan struct{})
+       done := make(chan struct{})
+       go func() {
+               close(start)
+               defer close(done)
+               for range 10 {
+                       w := httptest.NewRecorder()
+                       handler.ServeHTTP(w, req)
+               }
+       }()
+
+       // Add bypasses while the requests are in flight.
+       <-start
+       protection.AddTrustedOrigin("https://concurrent.example")
+       protection.AddInsecureBypassPattern("/foo/")
+       <-done
+
+       w = httptest.NewRecorder()
+       handler.ServeHTTP(w, req)
+
+       if w.Code != http.StatusOK {
+               t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
+       }
+}
+
+func TestCrossOriginProtectionServer(t *testing.T) {
+       protection := http.NewCrossOriginProtection()
+       protection.AddTrustedOrigin("https://trusted.example")
+       protection.AddInsecureBypassPattern("/bypass/")
+       handler := protection.Handler(okHandler)
+
+       ts := httptest.NewServer(handler)
+       defer ts.Close()
+
+       tests := []struct {
+               name           string
+               method         string
+               url            string
+               origin         string
+               secFetchSite   string
+               expectedStatus int
+       }{
+               {"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
+               {"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
+               {"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
+               {"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
+               {"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
+               {"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       req, err := http.NewRequest(tc.method, tc.url, nil)
+                       if err != nil {
+                               t.Fatalf("NewRequest: %v", err)
+                       }
+                       if tc.origin != "" {
+                               req.Header.Set("Origin", tc.origin)
+                       }
+                       if tc.secFetchSite != "" {
+                               req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
+                       }
+                       client := &http.Client{}
+                       resp, err := client.Do(req)
+                       if err != nil {
+                               t.Fatalf("Do: %v", err)
+                       }
+                       defer resp.Body.Close()
+                       if resp.StatusCode != tc.expectedStatus {
+                               t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
+                       }
+               })
+       }
+}