Merge pull request #1 from mmatczuk/mma/unit_subtest

test: unit test refactored to promote test separation using subtests
This commit is contained in:
Michał Matczuk 2016-11-03 13:52:42 +01:00 committed by GitHub
commit 92a10b3e1d

View file

@ -1,14 +1,16 @@
// Package h2tun_test implements integration test that ensures that client-server
// connectivity works properly.
package h2tun_test
import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@ -18,23 +20,33 @@ import (
)
const (
payloadInitialSize = 16
payloadInitialSize = 32
payloadLen = 10
)
var payload = randPayload(payloadInitialSize, payloadLen)
// testContext stores state shared between sub tests.
type testContext struct {
// handler is entry point for HTTP tests.
handler http.Handler
// listener is entry point for TCP tests.
listener net.Listener
// payload is pre generated random data.
payload [][]byte
}
func TestProxy(t *testing.T) {
t.Parallel()
var ctx testContext
func TestMain(m *testing.M) {
cert, id := selfSignedCert()
// prepare TCP listener
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Error("Listen failed", err)
panic(err)
}
defer l.Close()
// prepare h2tun server
s, err := h2tun.NewServer(&h2tun.ServerConfig{
TLSConfig: h2tuntest.TLSConfig(cert),
AllowedClients: []*h2tun.AllowedClient{
@ -46,59 +58,87 @@ func TestProxy(t *testing.T) {
},
})
if err != nil {
t.Error("Server creation failed", err)
panic(err)
}
s.Start()
defer s.Close()
// prepare h2tun client
c := h2tun.NewClient(&h2tun.ClientConfig{
ServerAddr: s.Addr(),
TLSClientConfig: h2tuntest.TLSConfig(cert),
Proxy: h2tuntest.EchoProxyFunc,
})
if err := c.Start(); err != nil {
t.Error("Client start failed", err)
panic(err)
}
defer c.Stop()
ctx.handler = s
ctx.listener = l
ctx.payload = randPayload(payloadInitialSize, payloadLen)
m.Run()
}
func randPayload(initialSize, n int) [][]byte {
payload := make([][]byte, n)
l := initialSize
for i := 0; i < n; i++ {
payload[i] = randBytes(l)
l *= 2
}
return payload
}
func randBytes(n int) []byte {
b := make([]byte, n)
read, err := rand.Read(b)
if err != nil {
panic(err)
}
if read != n {
panic("Read did not fill whole slice")
}
return b
}
func TestProxying(t *testing.T) {
data := []struct {
protocol string
repeat int
name string
seq []uint
}{
{"http", 4, []uint{100, 80, 60, 40, 20, 10}},
{"http", 2, []uint{20, 40, 60, 80, 100}},
{"http", 1, []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
{"tcp", 4, []uint{100, 80, 60, 40, 20, 10}},
{"tcp", 2, []uint{20, 40, 60, 80, 100}},
{"tcp", 1, []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
{"http", "small", []uint{100, 80, 60, 40, 20, 10}},
{"http", "mid", []uint{20, 40, 60, 80, 100}},
{"http", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
{"tcp", "small", []uint{100, 80, 60, 40, 20, 10}},
{"tcp", "mid", []uint{20, 40, 60, 80, 100}},
{"tcp", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
}
var wg sync.WaitGroup
for _, tt := range data {
for i := 0; i < tt.repeat; i++ {
wg.Add(1)
tt := tt
name := fmt.Sprintf("%s/%s", tt.protocol, tt.name)
t.Run(name, func(t *testing.T) {
t.Parallel()
switch tt.protocol {
case "http":
go testHTTP(t, s, tt.seq, &wg)
testHTTP(t, tt.seq)
case "tcp":
go testTCP(t, l.Addr().String(), tt.seq, &wg)
testTCP(t, tt.seq)
default:
panic("Unexpected network type")
}
}
})
}
wg.Wait()
}
func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
defer wg.Done()
func testHTTP(t *testing.T, seq []uint) {
var buf = bytes.NewBuffer(bigBuffer())
for idx, s := range seq {
for s > 0 {
r, err := http.NewRequest(http.MethodPost, "http://foobar.com/some/path", bytes.NewReader(payload[idx]))
r, err := http.NewRequest(http.MethodPost, "http://foobar.com/some/path", bytes.NewReader(ctx.payload[idx]))
if err != nil {
panic("Failed to create request")
}
@ -108,11 +148,11 @@ func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
Body: buf,
Code: 200,
}
h.ServeHTTP(w, r)
ctx.handler.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Error("Unexpected status code", w)
}
n, m := w.Body.Len(), len(payload[idx])
n, m := w.Body.Len(), len(ctx.payload[idx])
if n != m {
t.Log("Read mismatch", n, m)
}
@ -121,12 +161,10 @@ func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
}
}
func testTCP(t *testing.T, addr string, seq []uint, wg *sync.WaitGroup) {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
func testTCP(t *testing.T, seq []uint) {
conn, err := net.Dial("tcp", ctx.listener.Addr().String())
if err != nil {
t.Error("Dial failed", err)
t.Fatal("Dial failed", err)
}
defer conn.Close()
@ -134,12 +172,12 @@ func testTCP(t *testing.T, addr string, seq []uint, wg *sync.WaitGroup) {
var read, write int
for idx, s := range seq {
for s > 0 {
m, err := conn.Write(payload[idx])
m, err := conn.Write(ctx.payload[idx])
if err != nil {
t.Error("Write failed", err)
}
if m != len(payload[idx]) {
t.Log("Write mismatch", m, len(payload[idx]))
if m != len(ctx.payload[idx]) {
t.Log("Write mismatch", m, len(ctx.payload[idx]))
}
write += m
@ -183,28 +221,6 @@ func selfSignedCert() (tls.Certificate, id.ID) {
return cert, id.New(x509Cert.Raw)
}
func randPayload(initialSize, n int) [][]byte {
payload := make([][]byte, n)
l := initialSize
for i := 0; i < n; i++ {
payload[i] = randBytes(l)
l *= 2
}
return payload
}
func randBytes(n int) []byte {
b := make([]byte, n)
read, err := rand.Read(b)
if err != nil {
panic(err)
}
if read != n {
panic("Read did not fill whole slice")
}
return b
}
func bigBuffer() []byte {
return make([]byte, len(payload[len(payload)-1]))
return make([]byte, len(ctx.payload[len(ctx.payload)-1]))
}