From 5d39260079b5170e6b4263adb4022cc4b54153c4 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 8 Nov 2018 22:08:35 -0800 Subject: [PATCH] net: preserve unexpired context values for LookupIPAddr To avoid any cancelation of the parent context from affecting lookupGroup operations, Resolver.LookupIPAddr previously used an entirely new context created from context.Background(). However, this meant that all the values in the parent context with which LookupIPAddr was invoked were dropped. This change provides a custom context implementation that only preserves values of the parent context by composing context.Background() and the parent context. It only falls back to the parent context to perform value lookups if the parent context has not yet expired. This context is never canceled, and has no deadlines. Fixes #28600 Change-Id: If2f570caa26c65bad638b7102c35c79d5e429fea Reviewed-on: https://go-review.googlesource.com/c/148698 Run-TryBot: Emmanuel Odeke Reviewed-by: Brad Fitzpatrick --- src/net/lookup.go | 32 ++++++++++++++++-- src/net/lookup_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/src/net/lookup.go b/src/net/lookup.go index cb810dea26..e10889331e 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -205,6 +205,33 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err return r.lookupIPAddr(ctx, "ip", host) } +// onlyValuesCtx is a context that uses an underlying context +// for value lookup if the underlying context hasn't yet expired. +type onlyValuesCtx struct { + context.Context + lookupValues context.Context +} + +var _ context.Context = (*onlyValuesCtx)(nil) + +// Value performs a lookup if the original context hasn't expired. +func (ovc *onlyValuesCtx) Value(key interface{}) interface{} { + select { + case <-ovc.lookupValues.Done(): + return nil + default: + return ovc.lookupValues.Value(key) + } +} + +// withUnexpiredValuesPreserved returns a context.Context that only uses lookupCtx +// for its values, otherwise it is never canceled and has no deadline. +// If the lookup context expires, any looked up values will return nil. +// See Issue 28600. +func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context { + return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx} +} + // lookupIPAddr looks up host using the local resolver and particular network. // It returns a slice of that host's IPv4 and IPv6 addresses. func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) { @@ -231,8 +258,9 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP // We don't want a cancelation of ctx to affect the // lookupGroup operation. Otherwise if our context gets // canceled it might cause an error to be returned to a lookup - // using a completely different context. - lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background()) + // using a completely different context. However we need to preserve + // only the values in context. See Issue 28600. + lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx)) dnsWaitGroup.Add(1) ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) { diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go index aeeda8f7d0..65daa76467 100644 --- a/src/net/lookup_test.go +++ b/src/net/lookup_test.go @@ -1034,3 +1034,78 @@ func TestIPVersion(t *testing.T) { } } } + +// Issue 28600: The context that is used to lookup ips should always +// preserve the values from the context that was passed into LookupIPAddr. +func TestLookupIPAddrPreservesContextValues(t *testing.T) { + origTestHookLookupIP := testHookLookupIP + defer func() { testHookLookupIP = origTestHookLookupIP }() + + keyValues := []struct { + key, value interface{} + }{ + {"key-1", 12}, + {384, "value2"}, + {new(float64), 137}, + } + ctx := context.Background() + for _, kv := range keyValues { + ctx = context.WithValue(ctx, kv.key, kv.value) + } + + wantIPs := []IPAddr{ + {IP: IPv4(127, 0, 0, 1)}, + {IP: IPv6loopback}, + } + + checkCtxValues := func(ctx_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) { + for _, kv := range keyValues { + g, w := ctx_.Value(kv.key), kv.value + if !reflect.DeepEqual(g, w) { + t.Errorf("Value lookup:\n\tGot: %v\n\tWant: %v", g, w) + } + } + return wantIPs, nil + } + testHookLookupIP = checkCtxValues + + resolvers := []*Resolver{ + nil, + new(Resolver), + } + + for i, resolver := range resolvers { + gotIPs, err := resolver.LookupIPAddr(ctx, "golang.org") + if err != nil { + t.Errorf("Resolver #%d: unexpected error: %v", i, err) + } + if !reflect.DeepEqual(gotIPs, wantIPs) { + t.Errorf("#%d: mismatched IPAddr results\n\tGot: %v\n\tWant: %v", i, gotIPs, wantIPs) + } + } +} + +func TestWithUnexpiredValuesPreserved(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Insert a value into it. + key, value := "key-1", 2 + ctx = context.WithValue(ctx, key, value) + + // Now use the "values preserving context" like + // we would for LookupIPAddr. See Issue 28600. + ctx = withUnexpiredValuesPreserved(ctx) + + // Lookup before expiry. + if g, w := ctx.Value(key), value; g != w { + t.Errorf("Lookup before expiry: Got %v Want %v", g, w) + } + + // Cancel the context. + cancel() + + // Lookup after expiry should return nil + if g := ctx.Value(key); g != nil { + t.Errorf("Lookup after expiry: Got %v want nil", g) + } +} -- 2.50.0