import (
"bufio"
"bytes"
+ "compress/flate"
"compress/gzip"
"context"
"crypto/rand"
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
+ // lastFrameType holds the type of the last frame for verifying frame order.
+ lastFrameType http2FrameType
maxReadSize uint32
headerBuf [http2frameHeaderLen]byte
return err != nil
}
-// ReadFrame reads a single frame. The returned Frame is only valid
-// until the next call to ReadFrame.
+// ReadFrameHeader reads the header of the next frame.
+// It reads the 9-byte fixed frame header, and does not read any portion of the
+// frame payload. The caller is responsible for consuming the payload, either
+// with ReadFrameForHeader or directly from the Framer's io.Reader.
//
-// If the frame is larger than previously set with SetMaxReadFrameSize, the
-// returned error is ErrFrameTooLarge. Other errors may be of type
-// ConnectionError, StreamError, or anything else from the underlying
-// reader.
+// If the frame is larger than previously set with SetMaxReadFrameSize, it
+// returns the frame header and ErrFrameTooLarge.
//
-// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
-// indicates the stream responsible for the error.
-func (fr *http2Framer) ReadFrame() (http2Frame, error) {
+// If the returned FrameHeader.StreamID is non-zero, it indicates the stream
+// responsible for the error.
+func (fr *http2Framer) ReadFrameHeader() (http2FrameHeader, error) {
fr.errDetail = nil
- if fr.lastFrame != nil {
- fr.lastFrame.invalidate()
- }
fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
- return nil, err
+ return fh, err
}
if fh.Length > fr.maxReadSize {
if fh == http2invalidHTTP1LookingFrameHeader() {
- return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", http2ErrFrameTooLarge)
+ return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", http2ErrFrameTooLarge)
}
- return nil, http2ErrFrameTooLarge
+ return fh, http2ErrFrameTooLarge
+ }
+ if err := fr.checkFrameOrder(fh); err != nil {
+ return fh, err
+ }
+ return fh, nil
+}
+
+// ReadFrameForHeader reads the payload for the frame with the given FrameHeader.
+//
+// It behaves identically to ReadFrame, other than not checking the maximum
+// frame size.
+func (fr *http2Framer) ReadFrameForHeader(fh http2FrameHeader) (http2Frame, error) {
+ if fr.lastFrame != nil {
+ fr.lastFrame.invalidate()
}
payload := fr.getReadBuf(fh.Length)
if _, err := io.ReadFull(fr.r, payload); err != nil {
}
return nil, err
}
- if err := fr.checkFrameOrder(f); err != nil {
- return nil, err
- }
+ fr.lastFrame = f
if fr.logReads {
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
}
return f, nil
}
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame or ReadFrameBodyForHeader.
+//
+// If the frame is larger than previously set with SetMaxReadFrameSize, the
+// returned error is ErrFrameTooLarge. Other errors may be of type
+// ConnectionError, StreamError, or anything else from the underlying
+// reader.
+//
+// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID
+// indicates the stream responsible for the error.
+func (fr *http2Framer) ReadFrame() (http2Frame, error) {
+ fh, err := fr.ReadFrameHeader()
+ if err != nil {
+ return nil, err
+ }
+ return fr.ReadFrameForHeader(fh)
+}
+
// connError returns ConnectionError(code) but first
// stashes away a public reason to the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug
// checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous.
-func (fr *http2Framer) checkFrameOrder(f http2Frame) error {
- last := fr.lastFrame
- fr.lastFrame = f
+func (fr *http2Framer) checkFrameOrder(fh http2FrameHeader) error {
+ lastType := fr.lastFrameType
+ fr.lastFrameType = fh.Type
if fr.AllowIllegalReads {
return nil
}
- fh := f.Header()
if fr.lastHeaderStream != 0 {
if fh.Type != http2FrameContinuation {
return fr.connError(http2ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID,
- last.Header().Type, fr.lastHeaderStream))
+ lastType, fr.lastHeaderStream))
}
if fh.StreamID != fr.lastHeaderStream {
return fr.connError(http2ErrCodeProtocol,
// PriorityParam struct below is a superset of both schemes. The exported
// symbols are from RFC 7540 and the non-exported ones are from RFC 9218.
-// PriorityParam are the stream prioritzation parameters.
+// PriorityParam are the stream prioritization parameters.
type http2PriorityParam struct {
// StreamDep is a 31-bit stream identifier for the
// stream that this stream depends on. Zero means no
// completely unresponsive connection.
pendingResets int
+ // readBeforeStreamID is the smallest stream ID that has not been followed by
+ // a frame read from the peer. We use this to determine when a request may
+ // have been sent to a completely unresponsive connection:
+ // If the request ID is less than readBeforeStreamID, then we have had some
+ // indication of life on the connection since sending the request.
+ readBeforeStreamID uint32
+
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock.
// Lock reqmu BEFORE mu or wmu.
reqHeaderMu chan struct{}
+ // internalStateHook reports state changes back to the net/http.ClientConn.
+ // Note that this is different from the user state hook registered by
+ // net/http.ClientConn.SetStateHook: The internal hook calls ClientConn,
+ // which calls the user hook.
+ internalStateHook func()
+
// wmu is held while writing.
// Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
// Only acquire both at the same time when changing peer settings.
func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) {
if t.http2transportTestHooks != nil {
- return t.newClientConn(nil, singleUse)
+ return t.newClientConn(nil, singleUse, nil)
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
if err != nil {
return nil, err
}
- return t.newClientConn(tconn, singleUse)
+ return t.newClientConn(tconn, singleUse, nil)
}
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
- return t.newClientConn(c, t.disableKeepAlives())
+ return t.newClientConn(c, t.disableKeepAlives(), nil)
}
-func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
+func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook func()) (*http2ClientConn, error) {
conf := http2configFromTransport(t)
cc := &http2ClientConn{
t: t,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
lastActive: time.Now(),
+ internalStateHook: internalStateHook,
}
if t.http2transportTestHooks != nil {
t.http2transportTestHooks.newclientconn(cc)
maxConcurrentOkay = cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams)
}
- st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
- !cc.doNotReuse &&
- int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
- !cc.tooIdleLocked()
+ st.canTakeNewRequest = maxConcurrentOkay && cc.isUsableLocked()
// If this connection has never been used for a request and is closed,
// then let it take a request (which will fail).
return
}
+func (cc *http2ClientConn) isUsableLocked() bool {
+ return cc.goAway == nil &&
+ !cc.closed &&
+ !cc.closing &&
+ !cc.doNotReuse &&
+ int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
+ !cc.tooIdleLocked()
+}
+
+// canReserveLocked reports whether a net/http.ClientConn can reserve a slot on this conn.
+//
+// This follows slightly different rules than clientConnIdleState.canTakeNewRequest.
+// We only permit reservations up to the conn's concurrency limit.
+// This differs from ClientConn.ReserveNewRequest, which permits reservations
+// past the limit when StrictMaxConcurrentStreams is set.
+func (cc *http2ClientConn) canReserveLocked() bool {
+ if cc.currentRequestCountLocked() >= int(cc.maxConcurrentStreams) {
+ return false
+ }
+ if !cc.isUsableLocked() {
+ return false
+ }
+ return true
+}
+
// currentRequestCountLocked reports the number of concurrency slots currently in use,
// including active streams, reserved slots, and reset streams waiting for acknowledgement.
func (cc *http2ClientConn) currentRequestCountLocked() int {
return st.canTakeNewRequest
}
+// availableLocked reports the number of concurrency slots available.
+func (cc *http2ClientConn) availableLocked() int {
+ if !cc.canTakeNewRequestLocked() {
+ return 0
+ }
+ return max(0, int(cc.maxConcurrentStreams)-cc.currentRequestCountLocked())
+}
+
// tooIdleLocked reports whether this connection has been been sitting idle
// for too much wall time.
func (cc *http2ClientConn) tooIdleLocked() bool {
t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn)
defer t.Stop()
cc.tconn.Close()
+ cc.maybeCallStateHook()
}
// A tls.Conn.Close can hang for a long time if the peer is unresponsive.
}
bodyClosed := cs.reqBodyClosed
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
+ // Have we read any frames from the connection since sending this request?
+ readSinceStream := cc.readBeforeStreamID > cs.ID
cc.mu.Unlock()
if mustCloseBody {
cs.reqBody.Close()
//
// This could be due to the server becoming unresponsive.
// To avoid sending too many requests on a dead connection,
- // we let the request continue to consume a concurrency slot
- // until we can confirm the server is still responding.
+ // if we haven't read any frames from the connection since
+ // sending this request, we let it continue to consume
+ // a concurrency slot until we can confirm the server is
+ // still responding.
// We do this by sending a PING frame along with the RST_STREAM
// (unless a ping is already in flight).
//
// because it's short lived and will probably be closed before
// we get the ping response.
ping := false
- if !closeOnIdle {
+ if !closeOnIdle && !readSinceStream {
cc.mu.Lock()
// rstStreamPingsBlocked works around a gRPC behavior:
// see comment on the field for details.
}
close(cs.donec)
+ cc.maybeCallStateHook()
}
// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams.
// See comment on ClientConn.rstStreamPingsBlocked for details.
rl.cc.rstStreamPingsBlocked = false
}
+ rl.cc.readBeforeStreamID = rl.cc.nextStreamID
cs := rl.cc.streams[id]
if cs != nil && !cs.readAborted {
return cs
func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error {
cc := rl.cc
+ defer cc.maybeCallStateHook()
cc.mu.Lock()
defer cc.mu.Unlock()
func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error {
if f.IsAck() {
cc := rl.cc
+ defer cc.maybeCallStateHook()
cc.mu.Lock()
defer cc.mu.Unlock()
// If ack, notify listener if any
func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
+var http2errConcurrentReadOnResBody = errors.New("http2: concurrent read on response body")
+
// gzipReader wraps a response body so it can lazily
-// call gzip.NewReader on the first call to Read
+// get gzip.Reader from the pool on the first call to Read.
+// After Close is called it puts gzip.Reader to the pool immediately
+// if there is no Read in progress or later when Read completes.
type http2gzipReader struct {
_ http2incomparable
body io.ReadCloser // underlying Response.Body
- zr *gzip.Reader // lazily-initialized gzip reader
- zerr error // sticky error
+ mu sync.Mutex // guards zr and zerr
+ zr *gzip.Reader // stores gzip reader from the pool between reads
+ zerr error // sticky gzip reader init error or sentinel value to detect concurrent read and read after close
}
-func (gz *http2gzipReader) Read(p []byte) (n int, err error) {
+type http2eofReader struct{}
+
+func (http2eofReader) Read([]byte) (int, error) { return 0, io.EOF }
+
+func (http2eofReader) ReadByte() (byte, error) { return 0, io.EOF }
+
+var http2gzipPool = sync.Pool{New: func() any { return new(gzip.Reader) }}
+
+// gzipPoolGet gets a gzip.Reader from the pool and resets it to read from r.
+func http2gzipPoolGet(r io.Reader) (*gzip.Reader, error) {
+ zr := http2gzipPool.Get().(*gzip.Reader)
+ if err := zr.Reset(r); err != nil {
+ http2gzipPoolPut(zr)
+ return nil, err
+ }
+ return zr, nil
+}
+
+// gzipPoolPut puts a gzip.Reader back into the pool.
+func http2gzipPoolPut(zr *gzip.Reader) {
+ // Reset will allocate bufio.Reader if we pass it anything
+ // other than a flate.Reader, so ensure that it's getting one.
+ var r flate.Reader = http2eofReader{}
+ zr.Reset(r)
+ http2gzipPool.Put(zr)
+}
+
+// acquire returns a gzip.Reader for reading response body.
+// The reader must be released after use.
+func (gz *http2gzipReader) acquire() (*gzip.Reader, error) {
+ gz.mu.Lock()
+ defer gz.mu.Unlock()
if gz.zerr != nil {
- return 0, gz.zerr
+ return nil, gz.zerr
}
if gz.zr == nil {
- gz.zr, err = gzip.NewReader(gz.body)
- if err != nil {
- gz.zerr = err
- return 0, err
+ gz.zr, gz.zerr = http2gzipPoolGet(gz.body)
+ if gz.zerr != nil {
+ return nil, gz.zerr
}
}
- return gz.zr.Read(p)
+ ret := gz.zr
+ gz.zr, gz.zerr = nil, http2errConcurrentReadOnResBody
+ return ret, nil
}
-func (gz *http2gzipReader) Close() error {
- if err := gz.body.Close(); err != nil {
- return err
+// release returns the gzip.Reader to the pool if Close was called during Read.
+func (gz *http2gzipReader) release(zr *gzip.Reader) {
+ gz.mu.Lock()
+ defer gz.mu.Unlock()
+ if gz.zerr == http2errConcurrentReadOnResBody {
+ gz.zr, gz.zerr = zr, nil
+ } else { // fs.ErrClosed
+ http2gzipPoolPut(zr)
+ }
+}
+
+// close returns the gzip.Reader to the pool immediately or
+// signals release to do so after Read completes.
+func (gz *http2gzipReader) close() {
+ gz.mu.Lock()
+ defer gz.mu.Unlock()
+ if gz.zerr == nil && gz.zr != nil {
+ http2gzipPoolPut(gz.zr)
+ gz.zr = nil
}
gz.zerr = fs.ErrClosed
- return nil
+}
+
+func (gz *http2gzipReader) Read(p []byte) (n int, err error) {
+ zr, err := gz.acquire()
+ if err != nil {
+ return 0, err
+ }
+ defer gz.release(zr)
+
+ return zr.Read(p)
+}
+
+func (gz *http2gzipReader) Close() error {
+ gz.close()
+
+ return gz.body.Close()
}
type http2errorReader struct{ err error }
}
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
-// if there's already has a cached connection to the host.
+// if there's already a cached connection to the host.
// (The field is exported so it can be accessed via reflect from net/http; tested
// by TestNoDialH2RoundTripperType)
+//
+// A noDialH2RoundTripper is registered with http1.Transport.RegisterProtocol,
+// and the http1.Transport can use type assertions to call non-RoundTrip methods on it.
+// This lets us expose, for example, NewClientConn to net/http.
type http2noDialH2RoundTripper struct{ *http2Transport }
func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
return res, err
}
+func (rt http2noDialH2RoundTripper) NewClientConn(conn net.Conn, internalStateHook func()) (RoundTripper, error) {
+ tr := rt.http2Transport
+ cc, err := tr.newClientConn(conn, tr.disableKeepAlives(), internalStateHook)
+ if err != nil {
+ return nil, err
+ }
+
+ // RoundTrip should block when the conn is at its concurrency limit,
+ // not return an error. Setting strictMaxConcurrentStreams enables this.
+ cc.strictMaxConcurrentStreams = true
+
+ return http2netHTTPClientConn{cc}, nil
+}
+
+// netHTTPClientConn wraps ClientConn and implements the interface net/http expects from
+// the RoundTripper returned by NewClientConn.
+type http2netHTTPClientConn struct {
+ cc *http2ClientConn
+}
+
+func (cc http2netHTTPClientConn) RoundTrip(req *Request) (*Response, error) {
+ return cc.cc.RoundTrip(req)
+}
+
+func (cc http2netHTTPClientConn) Close() error {
+ return cc.cc.Close()
+}
+
+func (cc http2netHTTPClientConn) Err() error {
+ cc.cc.mu.Lock()
+ defer cc.cc.mu.Unlock()
+ if cc.cc.closed {
+ return errors.New("connection closed")
+ }
+ return nil
+}
+
+func (cc http2netHTTPClientConn) Reserve() error {
+ defer cc.cc.maybeCallStateHook()
+ cc.cc.mu.Lock()
+ defer cc.cc.mu.Unlock()
+ if !cc.cc.canReserveLocked() {
+ return errors.New("connection is unavailable")
+ }
+ cc.cc.streamsReserved++
+ return nil
+}
+
+func (cc http2netHTTPClientConn) Release() {
+ defer cc.cc.maybeCallStateHook()
+ cc.cc.mu.Lock()
+ defer cc.cc.mu.Unlock()
+ // We don't complain if streamsReserved is 0.
+ //
+ // This is consistent with RoundTrip: both Release and RoundTrip will
+ // consume a reservation iff one exists.
+ if cc.cc.streamsReserved > 0 {
+ cc.cc.streamsReserved--
+ }
+}
+
+func (cc http2netHTTPClientConn) Available() int {
+ cc.cc.mu.Lock()
+ defer cc.cc.mu.Unlock()
+ return cc.cc.availableLocked()
+}
+
+func (cc http2netHTTPClientConn) InFlight() int {
+ cc.cc.mu.Lock()
+ defer cc.cc.mu.Unlock()
+ return cc.cc.currentRequestCountLocked()
+}
+
+func (cc *http2ClientConn) maybeCallStateHook() {
+ if cc.internalStateHook != nil {
+ cc.internalStateHook()
+ }
+}
+
func (t *http2Transport) idleConnTimeout() time.Duration {
// to keep things backwards compatible, we use non-zero values of
// IdleConnTimeout, followed by using the IdleConnTimeout on the underlying
}
// writeQueue is used by implementations of WriteScheduler.
+//
+// Each writeQueue contains a queue of FrameWriteRequests, meant to store all
+// FrameWriteRequests associated with a given stream. This is implemented as a
+// two-stage queue: currQueue[currPos:] and nextQueue. Removing an item is done
+// by incrementing currPos of currQueue. Adding an item is done by appending it
+// to the nextQueue. If currQueue is empty when trying to remove an item, we
+// can swap currQueue and nextQueue to remedy the situation.
+// This two-stage queue is analogous to the use of two lists in Okasaki's
+// purely functional queue but without the overhead of reversing the list when
+// swapping stages.
+//
+// writeQueue also contains prev and next, this can be used by implementations
+// of WriteScheduler to construct data structures that represent the order of
+// writing between different streams (e.g. circular linked list).
type http2writeQueue struct {
- s []http2FrameWriteRequest
+ currQueue []http2FrameWriteRequest
+ nextQueue []http2FrameWriteRequest
+ currPos int
+
prev, next *http2writeQueue
}
-func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
+func (q *http2writeQueue) empty() bool {
+ return (len(q.currQueue) - q.currPos + len(q.nextQueue)) == 0
+}
func (q *http2writeQueue) push(wr http2FrameWriteRequest) {
- q.s = append(q.s, wr)
+ q.nextQueue = append(q.nextQueue, wr)
}
func (q *http2writeQueue) shift() http2FrameWriteRequest {
- if len(q.s) == 0 {
+ if q.empty() {
panic("invalid use of queue")
}
- wr := q.s[0]
- // TODO: less copy-happy queue.
- copy(q.s, q.s[1:])
- q.s[len(q.s)-1] = http2FrameWriteRequest{}
- q.s = q.s[:len(q.s)-1]
+ if q.currPos >= len(q.currQueue) {
+ q.currQueue, q.currPos, q.nextQueue = q.nextQueue, 0, q.currQueue[:0]
+ }
+ wr := q.currQueue[q.currPos]
+ q.currQueue[q.currPos] = http2FrameWriteRequest{}
+ q.currPos++
return wr
}
+func (q *http2writeQueue) peek() *http2FrameWriteRequest {
+ if q.currPos < len(q.currQueue) {
+ return &q.currQueue[q.currPos]
+ }
+ if len(q.nextQueue) > 0 {
+ return &q.nextQueue[0]
+ }
+ return nil
+}
+
// consume consumes up to n bytes from q.s[0]. If the frame is
// entirely consumed, it is removed from the queue. If the frame
// is partially consumed, the frame is kept with the consumed
// bytes removed. Returns true iff any bytes were consumed.
func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) {
- if len(q.s) == 0 {
+ if q.empty() {
return http2FrameWriteRequest{}, false
}
- consumed, rest, numresult := q.s[0].Consume(n)
+ consumed, rest, numresult := q.peek().Consume(n)
switch numresult {
case 0:
return http2FrameWriteRequest{}, false
case 1:
q.shift()
case 2:
- q.s[0] = rest
+ *q.peek() = rest
}
return consumed, true
}
// put inserts an unused writeQueue into the pool.
func (p *http2writeQueuePool) put(q *http2writeQueue) {
- for i := range q.s {
- q.s[i] = http2FrameWriteRequest{}
+ for i := range q.currQueue {
+ q.currQueue[i] = http2FrameWriteRequest{}
+ }
+ for i := range q.nextQueue {
+ q.nextQueue[i] = http2FrameWriteRequest{}
}
- q.s = q.s[:0]
+ q.currQueue = q.currQueue[:0]
+ q.nextQueue = q.nextQueue[:0]
+ q.currPos = 0
*p = append(*p, q)
}
func (z http2sortPriorityNodeSiblingsRFC7540) Less(i, k int) bool {
// Prefer the subtree that has sent fewer bytes relative to its weight.
// See sections 5.3.2 and 5.3.4.
- wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
- wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
+ wi, bi := float64(z[i].weight)+1, float64(z[i].subtreeBytes)
+ wk, bk := float64(z[k].weight)+1, float64(z[k].subtreeBytes)
if bi == 0 && bk == 0 {
return wi >= wk
}
q := n.q
ws.queuePool.put(&q)
- n.q.s = nil
if ws.maxClosedNodesInTree > 0 {
ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
} else {
prioritizeIncremental bool
}
-func http2newPriorityWriteSchedulerRFC9128() http2WriteScheduler {
+func http2newPriorityWriteSchedulerRFC9218() http2WriteScheduler {
ws := &http2priorityWriteSchedulerRFC9218{
streams: make(map[uint32]http2streamMetadata),
}
--- /dev/null
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnsmessage
+
+import (
+ "slices"
+)
+
+// An SVCBResource is an SVCB Resource record.
+type SVCBResource struct {
+ Priority uint16
+ Target Name
+ Params []SVCParam // Must be in strict increasing order by Key.
+}
+
+func (r *SVCBResource) realType() Type {
+ return TypeSVCB
+}
+
+// GoString implements fmt.GoStringer.GoString.
+func (r *SVCBResource) GoString() string {
+ b := []byte("dnsmessage.SVCBResource{" +
+ "Priority: " + printUint16(r.Priority) + ", " +
+ "Target: " + r.Target.GoString() + ", " +
+ "Params: []dnsmessage.SVCParam{")
+ if len(r.Params) > 0 {
+ b = append(b, r.Params[0].GoString()...)
+ for _, p := range r.Params[1:] {
+ b = append(b, ", "+p.GoString()...)
+ }
+ }
+ b = append(b, "}}"...)
+ return string(b)
+}
+
+// An HTTPSResource is an HTTPS Resource record.
+// It has the same format as the SVCB record.
+type HTTPSResource struct {
+ // Alias for SVCB resource record.
+ SVCBResource
+}
+
+func (r *HTTPSResource) realType() Type {
+ return TypeHTTPS
+}
+
+// GoString implements fmt.GoStringer.GoString.
+func (r *HTTPSResource) GoString() string {
+ return "dnsmessage.HTTPSResource{SVCBResource: " + r.SVCBResource.GoString() + "}"
+}
+
+// GetParam returns a parameter value by key.
+func (r *SVCBResource) GetParam(key SVCParamKey) (value []byte, ok bool) {
+ for i := range r.Params {
+ if r.Params[i].Key == key {
+ return r.Params[i].Value, true
+ }
+ if r.Params[i].Key > key {
+ break
+ }
+ }
+ return nil, false
+}
+
+// SetParam sets a parameter value by key.
+// The Params list is kept sorted by key.
+func (r *SVCBResource) SetParam(key SVCParamKey, value []byte) {
+ i := 0
+ for i < len(r.Params) {
+ if r.Params[i].Key >= key {
+ break
+ }
+ i++
+ }
+
+ if i < len(r.Params) && r.Params[i].Key == key {
+ r.Params[i].Value = value
+ return
+ }
+
+ r.Params = slices.Insert(r.Params, i, SVCParam{Key: key, Value: value})
+}
+
+// DeleteParam deletes a parameter by key.
+// It returns true if the parameter was present.
+func (r *SVCBResource) DeleteParam(key SVCParamKey) bool {
+ for i := range r.Params {
+ if r.Params[i].Key == key {
+ r.Params = slices.Delete(r.Params, i, i+1)
+ return true
+ }
+ if r.Params[i].Key > key {
+ break
+ }
+ }
+ return false
+}
+
+// A SVCParam is a service parameter.
+type SVCParam struct {
+ Key SVCParamKey
+ Value []byte
+}
+
+// GoString implements fmt.GoStringer.GoString.
+func (p SVCParam) GoString() string {
+ return "dnsmessage.SVCParam{" +
+ "Key: " + p.Key.GoString() + ", " +
+ "Value: []byte{" + printByteSlice(p.Value) + "}}"
+}
+
+// A SVCParamKey is a key for a service parameter.
+type SVCParamKey uint16
+
+// Values defined at https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml#dns-svcparamkeys.
+const (
+ SVCParamMandatory SVCParamKey = 0
+ SVCParamALPN SVCParamKey = 1
+ SVCParamNoDefaultALPN SVCParamKey = 2
+ SVCParamPort SVCParamKey = 3
+ SVCParamIPv4Hint SVCParamKey = 4
+ SVCParamECH SVCParamKey = 5
+ SVCParamIPv6Hint SVCParamKey = 6
+ SVCParamDOHPath SVCParamKey = 7
+ SVCParamOHTTP SVCParamKey = 8
+ SVCParamTLSSupportedGroups SVCParamKey = 9
+)
+
+var svcParamKeyNames = map[SVCParamKey]string{
+ SVCParamMandatory: "Mandatory",
+ SVCParamALPN: "ALPN",
+ SVCParamNoDefaultALPN: "NoDefaultALPN",
+ SVCParamPort: "Port",
+ SVCParamIPv4Hint: "IPv4Hint",
+ SVCParamECH: "ECH",
+ SVCParamIPv6Hint: "IPv6Hint",
+ SVCParamDOHPath: "DOHPath",
+ SVCParamOHTTP: "OHTTP",
+ SVCParamTLSSupportedGroups: "TLSSupportedGroups",
+}
+
+// String implements fmt.Stringer.String.
+func (k SVCParamKey) String() string {
+ if n, ok := svcParamKeyNames[k]; ok {
+ return n
+ }
+ return printUint16(uint16(k))
+}
+
+// GoString implements fmt.GoStringer.GoString.
+func (k SVCParamKey) GoString() string {
+ if n, ok := svcParamKeyNames[k]; ok {
+ return "dnsmessage.SVCParam" + n
+ }
+ return printUint16(uint16(k))
+}
+
+func (r *SVCBResource) pack(msg []byte, _ map[string]uint16, _ int) ([]byte, error) {
+ oldMsg := msg
+ msg = packUint16(msg, r.Priority)
+ // https://datatracker.ietf.org/doc/html/rfc3597#section-4 prohibits name
+ // compression for RR types that are not "well-known".
+ // https://datatracker.ietf.org/doc/html/rfc9460#section-2.2 explicitly states that
+ // compression of the Target is prohibited, following RFC 3597.
+ msg, err := r.Target.pack(msg, nil, 0)
+ if err != nil {
+ return oldMsg, &nestedError{"SVCBResource.Target", err}
+ }
+ var previousKey SVCParamKey
+ for i, param := range r.Params {
+ if i > 0 && param.Key <= previousKey {
+ return oldMsg, &nestedError{"SVCBResource.Params", errParamOutOfOrder}
+ }
+ if len(param.Value) > (1<<16)-1 {
+ return oldMsg, &nestedError{"SVCBResource.Params", errTooLongSVCBValue}
+ }
+ msg = packUint16(msg, uint16(param.Key))
+ msg = packUint16(msg, uint16(len(param.Value)))
+ msg = append(msg, param.Value...)
+ }
+ return msg, nil
+}
+
+func unpackSVCBResource(msg []byte, off int, length uint16) (SVCBResource, error) {
+ // Wire format reference: https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2.
+ r := SVCBResource{}
+ paramsOff := off
+ bodyEnd := off + int(length)
+
+ var err error
+ if r.Priority, paramsOff, err = unpackUint16(msg, paramsOff); err != nil {
+ return SVCBResource{}, &nestedError{"Priority", err}
+ }
+
+ if paramsOff, err = r.Target.unpack(msg, paramsOff); err != nil {
+ return SVCBResource{}, &nestedError{"Target", err}
+ }
+
+ // Two-pass parsing to avoid allocations.
+ // First, count the number of params.
+ n := 0
+ var totalValueLen uint16
+ off = paramsOff
+ var previousKey uint16
+ for off < bodyEnd {
+ var key, len uint16
+ if key, off, err = unpackUint16(msg, off); err != nil {
+ return SVCBResource{}, &nestedError{"Params key", err}
+ }
+ if n > 0 && key <= previousKey {
+ // As per https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2, clients MUST
+ // consider the RR malformed if the SvcParamKeys are not in strictly increasing numeric order
+ return SVCBResource{}, &nestedError{"Params", errParamOutOfOrder}
+ }
+ if len, off, err = unpackUint16(msg, off); err != nil {
+ return SVCBResource{}, &nestedError{"Params value length", err}
+ }
+ if off+int(len) > bodyEnd {
+ return SVCBResource{}, errResourceLen
+ }
+ totalValueLen += len
+ off += int(len)
+ n++
+ }
+ if off != bodyEnd {
+ return SVCBResource{}, errResourceLen
+ }
+
+ // Second, fill in the params.
+ r.Params = make([]SVCParam, n)
+ // valuesBuf is used to hold all param values to reduce allocations.
+ // Each param's Value slice will point into this buffer.
+ valuesBuf := make([]byte, totalValueLen)
+ off = paramsOff
+ for i := 0; i < n; i++ {
+ p := &r.Params[i]
+ var key, len uint16
+ if key, off, err = unpackUint16(msg, off); err != nil {
+ return SVCBResource{}, &nestedError{"param key", err}
+ }
+ p.Key = SVCParamKey(key)
+ if len, off, err = unpackUint16(msg, off); err != nil {
+ return SVCBResource{}, &nestedError{"param length", err}
+ }
+ if copy(valuesBuf, msg[off:off+int(len)]) != int(len) {
+ return SVCBResource{}, &nestedError{"param value", errCalcLen}
+ }
+ p.Value = valuesBuf[:len:len]
+ valuesBuf = valuesBuf[len:]
+ off += int(len)
+ }
+
+ return r, nil
+}
+
+// genericSVCBResource parses a single Resource Record compatible with SVCB.
+func (p *Parser) genericSVCBResource(svcbType Type) (SVCBResource, error) {
+ if !p.resHeaderValid || p.resHeaderType != svcbType {
+ return SVCBResource{}, ErrNotStarted
+ }
+ r, err := unpackSVCBResource(p.msg, p.off, p.resHeaderLength)
+ if err != nil {
+ return SVCBResource{}, err
+ }
+ p.off += int(p.resHeaderLength)
+ p.resHeaderValid = false
+ p.index++
+ return r, nil
+}
+
+// SVCBResource parses a single SVCBResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) SVCBResource() (SVCBResource, error) {
+ return p.genericSVCBResource(TypeSVCB)
+}
+
+// HTTPSResource parses a single HTTPSResource.
+//
+// One of the XXXHeader methods must have been called before calling this
+// method.
+func (p *Parser) HTTPSResource() (HTTPSResource, error) {
+ svcb, err := p.genericSVCBResource(TypeHTTPS)
+ if err != nil {
+ return HTTPSResource{}, err
+ }
+ return HTTPSResource{svcb}, nil
+}
+
+// genericSVCBResource is the generic implementation for adding SVCB-like resources.
+func (b *Builder) genericSVCBResource(h ResourceHeader, r SVCBResource) error {
+ if err := b.checkResourceSection(); err != nil {
+ return err
+ }
+ msg, lenOff, err := h.pack(b.msg, b.compression, b.start)
+ if err != nil {
+ return &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ if msg, err = r.pack(msg, b.compression, b.start); err != nil {
+ return &nestedError{"ResourceBody", err}
+ }
+ if err := h.fixLen(msg, lenOff, preLen); err != nil {
+ return err
+ }
+ if err := b.incrementSectionCount(); err != nil {
+ return err
+ }
+ b.msg = msg
+ return nil
+}
+
+// SVCBResource adds a single SVCBResource.
+func (b *Builder) SVCBResource(h ResourceHeader, r SVCBResource) error {
+ h.Type = r.realType()
+ return b.genericSVCBResource(h, r)
+}
+
+// HTTPSResource adds a single HTTPSResource.
+func (b *Builder) HTTPSResource(h ResourceHeader, r HTTPSResource) error {
+ h.Type = r.realType()
+ return b.genericSVCBResource(h, r.SVCBResource)
+}