// Requests without Sec-Fetch-Site or Origin headers are currently assumed to be
// either same-origin or non-browser requests, and are allowed.
//
+// The zero value of CrossOriginProtection is valid and has no trusted origins
+// or bypass patterns.
+//
// [Sec-Fetch-Site]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
// [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
// [Cross-Site Request Forgery (CSRF)]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
// [safe methods]: https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
type CrossOriginProtection struct {
- bypass *ServeMux
+ bypass atomic.Pointer[ServeMux]
trustedMu sync.RWMutex
trusted map[string]bool
deny atomic.Pointer[Handler]
// NewCrossOriginProtection returns a new [CrossOriginProtection] value.
func NewCrossOriginProtection() *CrossOriginProtection {
- return &CrossOriginProtection{
- bypass: NewServeMux(),
- trusted: make(map[string]bool),
- }
+ return &CrossOriginProtection{}
}
// AddTrustedOrigin allows all requests with an [Origin] header
}
c.trustedMu.Lock()
defer c.trustedMu.Unlock()
+ if c.trusted == nil {
+ c.trusted = make(map[string]bool)
+ }
c.trusted[origin] = true
return nil
}
// AddInsecureBypassPattern can be called concurrently with other methods
// or request handling, and applies to future requests.
func (c *CrossOriginProtection) AddInsecureBypassPattern(pattern string) {
- c.bypass.Handle(pattern, noopHandler)
+ var bypass *ServeMux
+
+ // Lazily initialize c.bypass
+ for {
+ bypass = c.bypass.Load()
+ if bypass != nil {
+ break
+ }
+ bypass = NewServeMux()
+ if c.bypass.CompareAndSwap(nil, bypass) {
+ break
+ }
+ }
+
+ bypass.Handle(pattern, noopHandler)
}
// SetDenyHandler sets a handler to invoke when a request is rejected.
// isRequestExempt checks the bypasses which require taking a lock, and should
// be deferred until the last moment.
func (c *CrossOriginProtection) isRequestExempt(req *Request) bool {
- if _, pattern := c.bypass.Handler(req); pattern != "" {
- // The request matches a bypass pattern.
- return true
+ if bypass := c.bypass.Load(); bypass != nil {
+ if _, pattern := bypass.Handler(req); pattern != "" {
+ // The request matches a bypass pattern.
+ return true
+ }
}
c.trustedMu.RLock()