diff --git a/client.go b/client.go index 041fb9f..3647ad7 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ package tunnel import ( "crypto/tls" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -57,18 +58,18 @@ type Client struct { // NewClient creates a new unconnected Client based on configuration. Caller // must invoke Start() on returned instance in order to connect server. -func NewClient(config *ClientConfig) *Client { +func NewClient(config *ClientConfig) (*Client, error) { if config.ServerAddr == "" { - panic("missing ServerAddr") + return nil, errors.New("missing ServerAddr") } if config.TLSClientConfig == nil { - panic("missing TLSClientConfig") + return nil, errors.New("missing TLSClientConfig") } - if config.Tunnels == nil || len(config.Tunnels) == 0 { - panic("missing Tunnels") + if len(config.Tunnels) == 0 { + return nil, errors.New("missing Tunnels") } if config.Proxy == nil { - panic("missing Proxy") + return nil, errors.New("missing Proxy") } logger := config.Logger @@ -82,7 +83,7 @@ func NewClient(config *ClientConfig) *Client { logger: logger, } - return c + return c, nil } // Start connects client to the server, it returns error if there is a diff --git a/client_test.go b/client_test.go index 7e04d9e..36310b9 100644 --- a/client_test.go +++ b/client_test.go @@ -23,7 +23,7 @@ func TestClient_Dial(t *testing.T) { s := httptest.NewTLSServer(nil) defer s.Close() - c := NewClient(&ClientConfig{ + c, err := NewClient(&ClientConfig{ ServerAddr: s.Listener.Addr().String(), TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -31,6 +31,9 @@ func TestClient_Dial(t *testing.T) { Tunnels: map[string]*proto.Tunnel{"test": {}}, Proxy: Proxy(ProxyFuncs{}), }) + if err != nil { + t.Fatal(err) + } conn, err := c.dial() if err != nil { @@ -58,7 +61,7 @@ func TestClient_DialBackoff(t *testing.T) { return nil, errors.New("foobar") } - c := NewClient(&ClientConfig{ + c, err := NewClient(&ClientConfig{ ServerAddr: "8.8.8.8", TLSClientConfig: &tls.Config{}, DialTLS: d, @@ -66,12 +69,14 @@ func TestClient_DialBackoff(t *testing.T) { Tunnels: map[string]*proto.Tunnel{"test": {}}, Proxy: Proxy(ProxyFuncs{}), }) + if err != nil { + t.Fatal(err) + } start := time.Now() - _, err := c.dial() - end := time.Now() + _, err = c.dial() - if end.Sub(start) < 100*time.Millisecond { + if time.Since(start) < 100*time.Millisecond { t.Fatal("Wait mismatch", err) } diff --git a/cmd/tunnel/tunnel.go b/cmd/tunnel/tunnel.go index 785010a..5ead7af 100644 --- a/cmd/tunnel/tunnel.go +++ b/cmd/tunnel/tunnel.go @@ -90,11 +90,11 @@ func main() { b, err := yaml.Marshal(config) if err != nil { - fatal("failed to load config: %s", err) + fatal("failed to dump config: %s", err) } logger.Log("config", string(b)) - client := tunnel.NewClient(&tunnel.ClientConfig{ + client, err := tunnel.NewClient(&tunnel.ClientConfig{ ServerAddr: config.ServerAddr, TLSClientConfig: tlsconf, Backoff: expBackoff(config.Backoff), @@ -102,9 +102,12 @@ func main() { Proxy: proxy(config.Tunnels, logger), Logger: logger, }) + if err != nil { + fatal("failed to create client: %s", err) + } if err := client.Start(); err != nil { - fatal("%s", err) + fatal("failed to start tunnels: %s", err) } } diff --git a/httpproxy.go b/httpproxy.go index ac85b5e..eb7f2b2 100644 --- a/httpproxy.go +++ b/httpproxy.go @@ -7,7 +7,6 @@ package tunnel import ( "bufio" "context" - "fmt" "io" "net" "net/http" @@ -38,10 +37,6 @@ type HTTPProxy struct { // NewHTTPProxy creates a new direct HTTPProxy, everything will be proxied to // localURL. func NewHTTPProxy(localURL *url.URL, logger log.Logger) *HTTPProxy { - if localURL == nil { - panic("empty localURL") - } - if logger == nil { logger = log.NewNopLogger() } @@ -58,10 +53,6 @@ func NewHTTPProxy(localURL *url.URL, logger log.Logger) *HTTPProxy { // NewMultiHTTPProxy creates a new dispatching HTTPProxy, requests may go to // different backends based on localURLMap. func NewMultiHTTPProxy(localURLMap map[string]*url.URL, logger log.Logger) *HTTPProxy { - if localURLMap == nil { - panic("empty localURLMap") - } - if logger == nil { logger = log.NewNopLogger() } @@ -91,7 +82,11 @@ func (p *HTTPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessag rw, ok := w.(http.ResponseWriter) if !ok { - panic(fmt.Sprintf("Expected http.ResponseWriter got %T", w)) + p.logger.Log( + "level", 0, + "msg", "expected http.ResponseWriter", + "ctrlMsg", msg, + ) } req, err := http.ReadRequest(bufio.NewReader(r)) @@ -168,7 +163,7 @@ func singleJoiningSlash(a, b string) string { } func (p *HTTPProxy) localURLFor(u *url.URL) *url.URL { - if p.localURLMap == nil { + if len(p.localURLMap) == 0 { return p.localURL } diff --git a/integration_test.go b/integration_test.go index 5412360..6fcf43e 100644 --- a/integration_test.go +++ b/integration_test.go @@ -51,7 +51,7 @@ func echoHTTP(t testing.TB, l net.Listener) { if r.Body != nil { body, err := ioutil.ReadAll(r.Body) if err != nil { - panic(err) + t.Fatal(err) } w.Write(body) } @@ -132,7 +132,7 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr, } cert, _ := selfSignedCert() - c := tunnel.NewClient(&tunnel.ClientConfig{ + c, err := tunnel.NewClient(&tunnel.ClientConfig{ ServerAddr: serverAddr, TLSClientConfig: tlsConfig(cert), Tunnels: tunnels, @@ -142,7 +142,14 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr, }), Logger: log.NewStdLogger(), }) - go c.Start() + if err != nil { + t.Fatal(err) + } + go func() { + if err := c.Start(); err != nil { + t.Log(err) + } + }() return c } diff --git a/tcpproxy.go b/tcpproxy.go index a55d48f..80e584e 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -31,10 +31,6 @@ type TCPProxy struct { // NewTCPProxy creates new direct TCPProxy, everything will be proxied to // localAddr. func NewTCPProxy(localAddr string, logger log.Logger) *TCPProxy { - if localAddr == "" { - panic("missing localAddr") - } - if logger == nil { logger = log.NewNopLogger() } @@ -48,10 +44,6 @@ func NewTCPProxy(localAddr string, logger log.Logger) *TCPProxy { // NewMultiTCPProxy creates a new dispatching TCPProxy, connections may go to // different backends based on localAddrMap. func NewMultiTCPProxy(localAddrMap map[string]string, logger log.Logger) *TCPProxy { - if localAddrMap == nil { - panic("missing localAddrMap") - } - if logger == nil { logger = log.NewNopLogger() } @@ -117,7 +109,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage } func (p *TCPProxy) localAddrFor(hostPort string) string { - if p.localAddrMap == nil { + if len(p.localAddrMap) == 0 { return p.localAddr }