return r.lookupIPAddr(ctx, "ip", host)
}
+// onlyValuesCtx is a context that uses an underlying context
+// for value lookup if the underlying context hasn't yet expired.
+type onlyValuesCtx struct {
+ context.Context
+ lookupValues context.Context
+}
+
+var _ context.Context = (*onlyValuesCtx)(nil)
+
+// Value performs a lookup if the original context hasn't expired.
+func (ovc *onlyValuesCtx) Value(key interface{}) interface{} {
+ select {
+ case <-ovc.lookupValues.Done():
+ return nil
+ default:
+ return ovc.lookupValues.Value(key)
+ }
+}
+
+// withUnexpiredValuesPreserved returns a context.Context that only uses lookupCtx
+// for its values, otherwise it is never canceled and has no deadline.
+// If the lookup context expires, any looked up values will return nil.
+// See Issue 28600.
+func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
+ return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx}
+}
+
// lookupIPAddr looks up host using the local resolver and particular network.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
// We don't want a cancelation of ctx to affect the
// lookupGroup operation. Otherwise if our context gets
// canceled it might cause an error to be returned to a lookup
- // using a completely different context.
- lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background())
+ // using a completely different context. However we need to preserve
+ // only the values in context. See Issue 28600.
+ lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx))
dnsWaitGroup.Add(1)
ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) {
}
}
}
+
+// Issue 28600: The context that is used to lookup ips should always
+// preserve the values from the context that was passed into LookupIPAddr.
+func TestLookupIPAddrPreservesContextValues(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+
+ keyValues := []struct {
+ key, value interface{}
+ }{
+ {"key-1", 12},
+ {384, "value2"},
+ {new(float64), 137},
+ }
+ ctx := context.Background()
+ for _, kv := range keyValues {
+ ctx = context.WithValue(ctx, kv.key, kv.value)
+ }
+
+ wantIPs := []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ }
+
+ checkCtxValues := func(ctx_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ for _, kv := range keyValues {
+ g, w := ctx_.Value(kv.key), kv.value
+ if !reflect.DeepEqual(g, w) {
+ t.Errorf("Value lookup:\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ }
+ return wantIPs, nil
+ }
+ testHookLookupIP = checkCtxValues
+
+ resolvers := []*Resolver{
+ nil,
+ new(Resolver),
+ }
+
+ for i, resolver := range resolvers {
+ gotIPs, err := resolver.LookupIPAddr(ctx, "golang.org")
+ if err != nil {
+ t.Errorf("Resolver #%d: unexpected error: %v", i, err)
+ }
+ if !reflect.DeepEqual(gotIPs, wantIPs) {
+ t.Errorf("#%d: mismatched IPAddr results\n\tGot: %v\n\tWant: %v", i, gotIPs, wantIPs)
+ }
+ }
+}
+
+func TestWithUnexpiredValuesPreserved(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Insert a value into it.
+ key, value := "key-1", 2
+ ctx = context.WithValue(ctx, key, value)
+
+ // Now use the "values preserving context" like
+ // we would for LookupIPAddr. See Issue 28600.
+ ctx = withUnexpiredValuesPreserved(ctx)
+
+ // Lookup before expiry.
+ if g, w := ctx.Value(key), value; g != w {
+ t.Errorf("Lookup before expiry: Got %v Want %v", g, w)
+ }
+
+ // Cancel the context.
+ cancel()
+
+ // Lookup after expiry should return nil
+ if g := ctx.Value(key); g != nil {
+ t.Errorf("Lookup after expiry: Got %v want nil", g)
+ }
+}