From: Filippo Valsorda Date: Wed, 21 May 2025 13:35:51 +0000 (+0200) Subject: net/http: add CrossOriginProtection X-Git-Tag: go1.25rc1~116 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=1881d680b0b573c32d3002c37902760668ffec0f;p=gostls13.git net/http: add CrossOriginProtection Fixes #73626 Change-Id: I6a6a4656862e7a38acb65c4815fb7a1e04896172 Reviewed-on: https://go-review.googlesource.com/c/go/+/674936 Reviewed-by: Damien Neil Auto-Submit: Filippo Valsorda LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- diff --git a/api/next/73626.txt b/api/next/73626.txt new file mode 100644 index 0000000000..ef4d0683b2 --- /dev/null +++ b/api/next/73626.txt @@ -0,0 +1,7 @@ +pkg net/http, func NewCrossOriginProtection() *CrossOriginProtection #73626 +pkg net/http, method (*CrossOriginProtection) AddInsecureBypassPattern(string) #73626 +pkg net/http, method (*CrossOriginProtection) AddTrustedOrigin(string) error #73626 +pkg net/http, method (*CrossOriginProtection) Check(*Request) error #73626 +pkg net/http, method (*CrossOriginProtection) Handler(Handler) Handler #73626 +pkg net/http, method (*CrossOriginProtection) SetDenyHandler(Handler) #73626 +pkg net/http, type CrossOriginProtection struct #73626 diff --git a/doc/next/6-stdlib/99-minor/net/http/73626.md b/doc/next/6-stdlib/99-minor/net/http/73626.md new file mode 100644 index 0000000000..88a204b8a4 --- /dev/null +++ b/doc/next/6-stdlib/99-minor/net/http/73626.md @@ -0,0 +1,7 @@ +The new [CrossOriginProtection] implements protections against [Cross-Site +Request Forgery (CSRF)][] by rejecting non-safe cross-origin browser requests. +It uses [modern browser Fetch metadata][Sec-Fetch-Site], doesn't require tokens +or cookies, and supports origin-based and pattern-based bypasses. + +[Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site +[Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF diff --git a/src/net/http/csrf.go b/src/net/http/csrf.go new file mode 100644 index 0000000000..a46071f806 --- /dev/null +++ b/src/net/http/csrf.go @@ -0,0 +1,182 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "errors" + "fmt" + "net/url" + "sync" + "sync/atomic" +) + +// CrossOriginProtection implements protections against [Cross-Site Request +// Forgery (CSRF)] by rejecting non-safe cross-origin browser requests. +// +// Cross-origin requests are currently detected with the [Sec-Fetch-Site] +// header, available in all browsers since 2023, or by comparing the hostname of +// the [Origin] header with the Host header. +// +// The GET, HEAD, and OPTIONS methods are [safe methods] and are always allowed. +// It's important that applications do not perform any state changing actions +// due to requests with safe methods. +// +// Requests without Sec-Fetch-Site or Origin headers are currently assumed to be +// either same-origin or non-browser requests, and are allowed. +// +// [Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site +// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin +// [Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF +// [safe methods]: https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP +type CrossOriginProtection struct { + bypass *ServeMux + trustedMu sync.RWMutex + trusted map[string]bool + deny atomic.Pointer[Handler] +} + +// NewCrossOriginProtection returns a new [CrossOriginProtection] value. +func NewCrossOriginProtection() *CrossOriginProtection { + return &CrossOriginProtection{ + bypass: NewServeMux(), + trusted: make(map[string]bool), + } +} + +// AddTrustedOrigin allows all requests with an [Origin] header +// which exactly matches the given value. +// +// Origin header values are of the form "scheme://host[:port]". +// +// AddTrustedOrigin can be called concurrently with other methods +// or request handling, and applies to future requests. +// +// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin +func (c *CrossOriginProtection) AddTrustedOrigin(origin string) error { + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("invalid origin %q: %w", origin, err) + } + if u.Scheme == "" { + return fmt.Errorf("invalid origin %q: scheme is required", origin) + } + if u.Host == "" { + return fmt.Errorf("invalid origin %q: host is required", origin) + } + if u.Path != "" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("invalid origin %q: path, query, and fragment are not allowed", origin) + } + c.trustedMu.Lock() + defer c.trustedMu.Unlock() + c.trusted[origin] = true + return nil +} + +var noopHandler = HandlerFunc(func(w ResponseWriter, r *Request) {}) + +// AddInsecureBypassPattern permits all requests that match the given pattern. +// The pattern syntax and precedence rules are the same as [ServeMux]. +// +// AddInsecureBypassPattern can be called concurrently with other methods +// or request handling, and applies to future requests. +func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) { + c.bypass.Handle(pattern, noopHandler) +} + +// SetDenyHandler sets a handler to invoke when a request is rejected. +// The default error handler responds with a 403 Forbidden status. +// +// SetDenyHandler can be called concurrently with other methods +// or request handling, and applies to future requests. +// +// Check does not call the error handler. +func (c *CrossOriginProtection) SetDenyHandler(h Handler) { + if h == nil { + c.deny.Store(nil) + return + } + c.deny.Store(&h) +} + +// Check applies cross-origin checks to a request. +// It returns an error if the request should be rejected. +func (c *CrossOriginProtection) Check(req *Request) error { + switch req.Method { + case "GET", "HEAD", "OPTIONS": + // Safe methods are always allowed. + return nil + } + + switch req.Header.Get("Sec-Fetch-Site") { + case "": + // No Sec-Fetch-Site header is present. + // Fallthrough to check the Origin header. + case "same-origin", "none": + return nil + default: + if c.isRequestExempt(req) { + return nil + } + return errors.New("cross-origin request detected from Sec-Fetch-Site header") + } + + origin := req.Header.Get("Origin") + if origin == "" { + // Neither Sec-Fetch-Site nor Origin headers are present. + // Either the request is same-origin or not a browser request. + return nil + } + + if o, err := url.Parse(origin); err == nil && o.Host == req.Host { + // The Origin header matches the Host header. Note that the Host header + // doesn't include the scheme, so we don't know if this might be an + // HTTP→HTTPS cross-origin request. We fail open, since all modern + // browsers support Sec-Fetch-Site since 2023, and running an older + // browser makes a clear security trade-off already. Sites can mitigate + // this with HTTP Strict Transport Security (HSTS). + return nil + } + + if c.isRequestExempt(req) { + return nil + } + return errors.New("cross-origin request detected, and/or browser is out of date: " + + "Sec-Fetch-Site is missing, and Origin does not match Host") +} + +// isRequestExempt checks the bypasses which require taking a lock, and should +// be deferred until the last moment. +func (c *CrossOriginProtection) isRequestExempt(req *Request) bool { + if _, pattern := c.bypass.Handler(req); pattern != "" { + // The request matches a bypass pattern. + return true + } + + c.trustedMu.RLock() + defer c.trustedMu.RUnlock() + origin := req.Header.Get("Origin") + // The request matches a trusted origin. + return origin != "" && c.trusted[origin] +} + +// Handler returns a handler that applies cross-origin checks +// before invoking the handler h. +// +// If a request fails cross-origin checks, the request is rejected +// with a 403 Forbidden status or handled with the handler passed +// to [CrossOriginProtection.SetDenyHandler]. +func (c *CrossOriginProtection) Handler(h Handler) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + if err := c.Check(r); err != nil { + if deny := c.deny.Load(); deny != nil { + (*deny).ServeHTTP(w, r) + return + } + Error(w, err.Error(), StatusForbidden) + return + } + h.ServeHTTP(w, r) + }) +} diff --git a/src/net/http/csrf_test.go b/src/net/http/csrf_test.go new file mode 100644 index 0000000000..30986a43b9 --- /dev/null +++ b/src/net/http/csrf_test.go @@ -0,0 +1,330 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// httptestNewRequest works around https://go.dev/issue/73151. +func httptestNewRequest(method, target string) *http.Request { + req := httptest.NewRequest(method, target, nil) + req.URL.Scheme = "" + req.URL.Host = "" + return req +} + +var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +func TestCrossOriginProtectionSecFetchSite(t *testing.T) { + protection := http.NewCrossOriginProtection() + handler := protection.Handler(okHandler) + + tests := []struct { + name string + method string + secFetchSite string + origin string + expectedStatus int + }{ + {"same-origin allowed", "POST", "same-origin", "", http.StatusOK}, + {"none allowed", "POST", "none", "", http.StatusOK}, + {"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden}, + {"same-site blocked", "POST", "same-site", "", http.StatusForbidden}, + + {"no header with no origin", "POST", "", "", http.StatusOK}, + {"no header with matching origin", "POST", "", "https://example.com", http.StatusOK}, + {"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden}, + {"no header with null origin", "POST", "", "null", http.StatusForbidden}, + + {"GET allowed", "GET", "cross-site", "", http.StatusOK}, + {"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK}, + {"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK}, + {"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptestNewRequest(tc.method, "https://example.com/") + if tc.secFetchSite != "" { + req.Header.Set("Sec-Fetch-Site", tc.secFetchSite) + } + if tc.origin != "" { + req.Header.Set("Origin", tc.origin) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus { + t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus) + } + }) + } +} + +func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) { + protection := http.NewCrossOriginProtection() + err := protection.AddTrustedOrigin("https://trusted.example") + if err != nil { + t.Fatalf("AddTrustedOrigin: %v", err) + } + handler := protection.Handler(okHandler) + + tests := []struct { + name string + origin string + secFetchSite string + expectedStatus int + }{ + {"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK}, + {"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK}, + {"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden}, + {"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptestNewRequest("POST", "https://example.com/") + req.Header.Set("Origin", tc.origin) + if tc.secFetchSite != "" { + req.Header.Set("Sec-Fetch-Site", tc.secFetchSite) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus { + t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus) + } + }) + } +} + +func TestCrossOriginProtectionPatternBypass(t *testing.T) { + protection := http.NewCrossOriginProtection() + protection.AddInsecureBypassPattern("/bypass/") + protection.AddInsecureBypassPattern("/only/{foo}") + handler := protection.Handler(okHandler) + + tests := []struct { + name string + path string + secFetchSite string + expectedStatus int + }{ + {"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK}, + {"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK}, + {"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden}, + {"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden}, + + {"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusOK}, + {"redirect to bypass path with trailing slash", "/bypass", "", http.StatusOK}, + {"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden}, + {"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden}, + + {"wildcard bypass", "/only/123", "", http.StatusOK}, + {"non-wildcard", "/only/123/foo", "", http.StatusForbidden}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptestNewRequest("POST", "https://example.com"+tc.path) + req.Header.Set("Origin", "https://attacker.example") + if tc.secFetchSite != "" { + req.Header.Set("Sec-Fetch-Site", tc.secFetchSite) + } + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus { + t.Errorf("got status %d, want %d", w.Code, tc.expectedStatus) + } + }) + } +} + +func TestCrossOriginProtectionSetDenyHandler(t *testing.T) { + protection := http.NewCrossOriginProtection() + + handler := protection.Handler(okHandler) + + req := httptestNewRequest("POST", "https://example.com/") + req.Header.Set("Sec-Fetch-Site", "cross-site") + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden) + } + + customErrHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + io.WriteString(w, "custom error") + }) + protection.SetDenyHandler(customErrHandler) + + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusTeapot { + t.Errorf("got status %d, want %d", w.Code, http.StatusTeapot) + } + + if !strings.Contains(w.Body.String(), "custom error") { + t.Errorf("expected custom error message, got: %q", w.Body.String()) + } + + req = httptestNewRequest("GET", "https://example.com/") + + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("got status %d, want %d", w.Code, http.StatusOK) + } + + protection.SetDenyHandler(nil) + + req = httptestNewRequest("POST", "https://example.com/") + req.Header.Set("Sec-Fetch-Site", "cross-site") + + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) { + protection := http.NewCrossOriginProtection() + + tests := []struct { + name string + origin string + wantErr bool + }{ + {"valid origin", "https://example.com", false}, + {"valid origin with port", "https://example.com:8080", false}, + {"http origin", "http://example.com", false}, + {"missing scheme", "example.com", true}, + {"missing host", "https://", true}, + {"trailing slash", "https://example.com/", true}, + {"with path", "https://example.com/path", true}, + {"with query", "https://example.com?query=value", true}, + {"with fragment", "https://example.com#fragment", true}, + {"invalid url", "https://ex ample.com", true}, + {"empty string", "", true}, + {"null", "null", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := protection.AddTrustedOrigin(tc.origin) + if (err != nil) != tc.wantErr { + t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", tc.origin, err, tc.wantErr) + } + }) + } +} + +func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) { + protection := http.NewCrossOriginProtection() + handler := protection.Handler(okHandler) + + req := httptestNewRequest("POST", "https://example.com/") + req.Header.Set("Origin", "https://concurrent.example") + req.Header.Set("Sec-Fetch-Site", "cross-site") + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden) + } + + start := make(chan struct{}) + done := make(chan struct{}) + go func() { + close(start) + defer close(done) + for range 10 { + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + } + }() + + // Add bypasses while the requests are in flight. + <-start + protection.AddTrustedOrigin("https://concurrent.example") + protection.AddInsecureBypassPattern("/foo/") + <-done + + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK) + } +} + +func TestCrossOriginProtectionServer(t *testing.T) { + protection := http.NewCrossOriginProtection() + protection.AddTrustedOrigin("https://trusted.example") + protection.AddInsecureBypassPattern("/bypass/") + handler := protection.Handler(okHandler) + + ts := httptest.NewServer(handler) + defer ts.Close() + + tests := []struct { + name string + method string + url string + origin string + secFetchSite string + expectedStatus int + }{ + {"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden}, + {"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK}, + {"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK}, + {"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK}, + {"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden}, + {"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(tc.method, tc.url, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if tc.origin != "" { + req.Header.Set("Origin", tc.origin) + } + if tc.secFetchSite != "" { + req.Header.Set("Sec-Fetch-Site", tc.secFetchSite) + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != tc.expectedStatus { + t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus) + } + }) + } +}