import (
"bytes"
"compress/gzip"
+ "crypto/rand"
+ "crypto/sha1"
"crypto/tls"
"fmt"
+ "hash"
"io"
"io/ioutil"
"log"
t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
}
}
+
+func TestBidiStreamReverseProxy(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if _, err := io.Copy(w, r.Body); err != nil {
+ log.Printf("bidi backend copy: %v", err)
+ }
+ }))
+ defer backend.close()
+
+ backURL, err := url.Parse(backend.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rp := httputil.NewSingleHostReverseProxy(backURL)
+ rp.Transport = backend.tr
+ proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ rp.ServeHTTP(w, r)
+ }))
+ defer proxy.close()
+
+ bodyRes := make(chan interface{}, 1) // error or hash.Hash
+ pr, pw := io.Pipe()
+ req, _ := NewRequest("PUT", proxy.ts.URL, pr)
+ const size = 4 << 20
+ go func() {
+ h := sha1.New()
+ _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
+ go pw.Close()
+ if err != nil {
+ bodyRes <- err
+ } else {
+ bodyRes <- h
+ }
+ }()
+ res, err := backend.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ hgot := sha1.New()
+ n, err := io.Copy(hgot, res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != size {
+ t.Fatalf("got %d bytes; want %d", n, size)
+ }
+ select {
+ case v := <-bodyRes:
+ switch v := v.(type) {
+ default:
+ t.Fatalf("body copy: %v", err)
+ case hash.Hash:
+ if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
+ t.Errorf("written bytes didn't match received bytes")
+ }
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("timeout")
+ }
+
+}