g.mu.Unlock()
 }
 
-// Forget tells the singleflight to forget about a key.  Future calls
-// to Do for this key will call the function rather than waiting for
-// an earlier call to complete.
-func (g *Group) Forget(key string) {
+// ForgetUnshared tells the singleflight to forget about a key if it is not
+// shared with any other goroutines. Future calls to Do for a forgotten key
+// will call the function rather than waiting for an earlier call to complete.
+// Returns whether the key was forgotten or unknown--that is, whether no
+// other goroutines are waiting for the result.
+func (g *Group) ForgetUnshared(key string) bool {
        g.mu.Lock()
-       delete(g.m, key)
-       g.mu.Unlock()
+       defer g.mu.Unlock()
+       c, ok := g.m[key]
+       if !ok {
+               return true
+       }
+       if c.dups == 0 {
+               delete(g.m, key)
+               return true
+       }
+       return false
 }
 
                resolverFunc = alt
        }
 
+       // We don't want a cancelation of ctx to affect the
+       // lookupGroup operation. Otherwise if our context gets
+       // canceled it might cause an error to be returned to a lookup
+       // using a completely different context.
+       lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background())
+
        dnsWaitGroup.Add(1)
        ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
                defer dnsWaitGroup.Done()
-               return testHookLookupIP(ctx, resolverFunc, host)
+               return testHookLookupIP(lookupGroupCtx, resolverFunc, host)
        })
        if !called {
                dnsWaitGroup.Done()
 
        select {
        case <-ctx.Done():
-               // If the DNS lookup timed out for some reason, force
-               // future requests to start the DNS lookup again
-               // rather than waiting for the current lookup to
-               // complete. See issue 8602.
-               ctxErr := ctx.Err()
-               if ctxErr == context.DeadlineExceeded {
-                       lookupGroup.Forget(host)
+               // Our context was canceled. If we are the only
+               // goroutine looking up this key, then drop the key
+               // from the lookupGroup and cancel the lookup.
+               // If there are other goroutines looking up this key,
+               // let the lookup continue uncanceled, and let later
+               // lookups with the same key share the result.
+               // See issues 8602, 20703, 22724.
+               if lookupGroup.ForgetUnshared(host) {
+                       lookupGroupCancel()
+               } else {
+                       go func() {
+                               <-ch
+                               lookupGroupCancel()
+                       }()
                }
-               err := mapErr(ctxErr)
+               err := mapErr(ctx.Err())
                if trace != nil && trace.DNSDone != nil {
                        trace.DNSDone(nil, false, err)
                }
                return nil, err
        case r := <-ch:
+               lookupGroupCancel()
                if trace != nil && trace.DNSDone != nil {
                        addrs, _ := r.Val.([]IPAddr)
                        trace.DNSDone(ipAddrsEface(addrs), r.Shared, r.Err)
 
                t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost)
        }
 }
+
+func TestLookupContextCancel(t *testing.T) {
+       if testenv.Builder() == "" {
+               testenv.MustHaveExternalNetwork(t)
+       }
+       if runtime.GOOS == "nacl" {
+               t.Skip("skip on nacl")
+       }
+
+       defer dnsWaitGroup.Wait()
+
+       ctx, ctxCancel := context.WithCancel(context.Background())
+       ctxCancel()
+       _, err := DefaultResolver.LookupIPAddr(ctx, "google.com")
+       if err != errCanceled {
+               testenv.SkipFlakyNet(t)
+               t.Fatal(err)
+       }
+       ctx = context.Background()
+       _, err = DefaultResolver.LookupIPAddr(ctx, "google.com")
+       if err != nil {
+               testenv.SkipFlakyNet(t)
+               t.Fatal(err)
+       }
+}