mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
Merge pull request #1 from mmatczuk/mma/unit_subtest
test: unit test refactored to promote test separation using subtests
This commit is contained in:
commit
92a10b3e1d
1 changed files with 76 additions and 60 deletions
|
|
@ -1,14 +1,16 @@
|
|||
// Package h2tun_test implements integration test that ensures that client-server
|
||||
// connectivity works properly.
|
||||
package h2tun_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -18,23 +20,33 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
payloadInitialSize = 16
|
||||
payloadInitialSize = 32
|
||||
payloadLen = 10
|
||||
)
|
||||
|
||||
var payload = randPayload(payloadInitialSize, payloadLen)
|
||||
// testContext stores state shared between sub tests.
|
||||
type testContext struct {
|
||||
// handler is entry point for HTTP tests.
|
||||
handler http.Handler
|
||||
// listener is entry point for TCP tests.
|
||||
listener net.Listener
|
||||
// payload is pre generated random data.
|
||||
payload [][]byte
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
var ctx testContext
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
cert, id := selfSignedCert()
|
||||
|
||||
// prepare TCP listener
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Error("Listen failed", err)
|
||||
panic(err)
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
// prepare h2tun server
|
||||
s, err := h2tun.NewServer(&h2tun.ServerConfig{
|
||||
TLSConfig: h2tuntest.TLSConfig(cert),
|
||||
AllowedClients: []*h2tun.AllowedClient{
|
||||
|
|
@ -46,59 +58,87 @@ func TestProxy(t *testing.T) {
|
|||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Error("Server creation failed", err)
|
||||
panic(err)
|
||||
}
|
||||
s.Start()
|
||||
defer s.Close()
|
||||
|
||||
// prepare h2tun client
|
||||
c := h2tun.NewClient(&h2tun.ClientConfig{
|
||||
ServerAddr: s.Addr(),
|
||||
TLSClientConfig: h2tuntest.TLSConfig(cert),
|
||||
Proxy: h2tuntest.EchoProxyFunc,
|
||||
})
|
||||
if err := c.Start(); err != nil {
|
||||
t.Error("Client start failed", err)
|
||||
panic(err)
|
||||
}
|
||||
defer c.Stop()
|
||||
|
||||
ctx.handler = s
|
||||
ctx.listener = l
|
||||
ctx.payload = randPayload(payloadInitialSize, payloadLen)
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func randPayload(initialSize, n int) [][]byte {
|
||||
payload := make([][]byte, n)
|
||||
l := initialSize
|
||||
for i := 0; i < n; i++ {
|
||||
payload[i] = randBytes(l)
|
||||
l *= 2
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
read, err := rand.Read(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if read != n {
|
||||
panic("Read did not fill whole slice")
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestProxying(t *testing.T) {
|
||||
data := []struct {
|
||||
protocol string
|
||||
repeat int
|
||||
name string
|
||||
seq []uint
|
||||
}{
|
||||
{"http", 4, []uint{100, 80, 60, 40, 20, 10}},
|
||||
{"http", 2, []uint{20, 40, 60, 80, 100}},
|
||||
{"http", 1, []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
|
||||
|
||||
{"tcp", 4, []uint{100, 80, 60, 40, 20, 10}},
|
||||
{"tcp", 2, []uint{20, 40, 60, 80, 100}},
|
||||
{"tcp", 1, []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
|
||||
{"http", "small", []uint{100, 80, 60, 40, 20, 10}},
|
||||
{"http", "mid", []uint{20, 40, 60, 80, 100}},
|
||||
{"http", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
|
||||
{"tcp", "small", []uint{100, 80, 60, 40, 20, 10}},
|
||||
{"tcp", "mid", []uint{20, 40, 60, 80, 100}},
|
||||
{"tcp", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 100}},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, tt := range data {
|
||||
for i := 0; i < tt.repeat; i++ {
|
||||
wg.Add(1)
|
||||
tt := tt
|
||||
name := fmt.Sprintf("%s/%s", tt.protocol, tt.name)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
switch tt.protocol {
|
||||
case "http":
|
||||
go testHTTP(t, s, tt.seq, &wg)
|
||||
testHTTP(t, tt.seq)
|
||||
case "tcp":
|
||||
go testTCP(t, l.Addr().String(), tt.seq, &wg)
|
||||
testTCP(t, tt.seq)
|
||||
default:
|
||||
panic("Unexpected network type")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
func testHTTP(t *testing.T, seq []uint) {
|
||||
var buf = bytes.NewBuffer(bigBuffer())
|
||||
for idx, s := range seq {
|
||||
for s > 0 {
|
||||
r, err := http.NewRequest(http.MethodPost, "http://foobar.com/some/path", bytes.NewReader(payload[idx]))
|
||||
r, err := http.NewRequest(http.MethodPost, "http://foobar.com/some/path", bytes.NewReader(ctx.payload[idx]))
|
||||
if err != nil {
|
||||
panic("Failed to create request")
|
||||
}
|
||||
|
|
@ -108,11 +148,11 @@ func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
|
|||
Body: buf,
|
||||
Code: 200,
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
ctx.handler.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Error("Unexpected status code", w)
|
||||
}
|
||||
n, m := w.Body.Len(), len(payload[idx])
|
||||
n, m := w.Body.Len(), len(ctx.payload[idx])
|
||||
if n != m {
|
||||
t.Log("Read mismatch", n, m)
|
||||
}
|
||||
|
|
@ -121,12 +161,10 @@ func testHTTP(t *testing.T, h http.Handler, seq []uint, wg *sync.WaitGroup) {
|
|||
}
|
||||
}
|
||||
|
||||
func testTCP(t *testing.T, addr string, seq []uint, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
func testTCP(t *testing.T, seq []uint) {
|
||||
conn, err := net.Dial("tcp", ctx.listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Error("Dial failed", err)
|
||||
t.Fatal("Dial failed", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
|
|
@ -134,12 +172,12 @@ func testTCP(t *testing.T, addr string, seq []uint, wg *sync.WaitGroup) {
|
|||
var read, write int
|
||||
for idx, s := range seq {
|
||||
for s > 0 {
|
||||
m, err := conn.Write(payload[idx])
|
||||
m, err := conn.Write(ctx.payload[idx])
|
||||
if err != nil {
|
||||
t.Error("Write failed", err)
|
||||
}
|
||||
if m != len(payload[idx]) {
|
||||
t.Log("Write mismatch", m, len(payload[idx]))
|
||||
if m != len(ctx.payload[idx]) {
|
||||
t.Log("Write mismatch", m, len(ctx.payload[idx]))
|
||||
}
|
||||
write += m
|
||||
|
||||
|
|
@ -183,28 +221,6 @@ func selfSignedCert() (tls.Certificate, id.ID) {
|
|||
return cert, id.New(x509Cert.Raw)
|
||||
}
|
||||
|
||||
func randPayload(initialSize, n int) [][]byte {
|
||||
payload := make([][]byte, n)
|
||||
l := initialSize
|
||||
for i := 0; i < n; i++ {
|
||||
payload[i] = randBytes(l)
|
||||
l *= 2
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
read, err := rand.Read(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if read != n {
|
||||
panic("Read did not fill whole slice")
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func bigBuffer() []byte {
|
||||
return make([]byte, len(payload[len(payload)-1]))
|
||||
return make([]byte, len(ctx.payload[len(ctx.payload)-1]))
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue