tunnel: removal of panics

This commit is contained in:
Michał Matczuk 2017-11-25 22:35:04 +01:00
parent fc43db429f
commit 65877e76a8
6 changed files with 41 additions and 38 deletions

View file

@ -7,6 +7,7 @@ package tunnel
import ( import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -57,18 +58,18 @@ type Client struct {
// NewClient creates a new unconnected Client based on configuration. Caller // NewClient creates a new unconnected Client based on configuration. Caller
// must invoke Start() on returned instance in order to connect server. // must invoke Start() on returned instance in order to connect server.
func NewClient(config *ClientConfig) *Client { func NewClient(config *ClientConfig) (*Client, error) {
if config.ServerAddr == "" { if config.ServerAddr == "" {
panic("missing ServerAddr") return nil, errors.New("missing ServerAddr")
} }
if config.TLSClientConfig == nil { if config.TLSClientConfig == nil {
panic("missing TLSClientConfig") return nil, errors.New("missing TLSClientConfig")
} }
if config.Tunnels == nil || len(config.Tunnels) == 0 { if len(config.Tunnels) == 0 {
panic("missing Tunnels") return nil, errors.New("missing Tunnels")
} }
if config.Proxy == nil { if config.Proxy == nil {
panic("missing Proxy") return nil, errors.New("missing Proxy")
} }
logger := config.Logger logger := config.Logger
@ -82,7 +83,7 @@ func NewClient(config *ClientConfig) *Client {
logger: logger, logger: logger,
} }
return c return c, nil
} }
// Start connects client to the server, it returns error if there is a // Start connects client to the server, it returns error if there is a

View file

@ -23,7 +23,7 @@ func TestClient_Dial(t *testing.T) {
s := httptest.NewTLSServer(nil) s := httptest.NewTLSServer(nil)
defer s.Close() defer s.Close()
c := NewClient(&ClientConfig{ c, err := NewClient(&ClientConfig{
ServerAddr: s.Listener.Addr().String(), ServerAddr: s.Listener.Addr().String(),
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -31,6 +31,9 @@ func TestClient_Dial(t *testing.T) {
Tunnels: map[string]*proto.Tunnel{"test": {}}, Tunnels: map[string]*proto.Tunnel{"test": {}},
Proxy: Proxy(ProxyFuncs{}), Proxy: Proxy(ProxyFuncs{}),
}) })
if err != nil {
t.Fatal(err)
}
conn, err := c.dial() conn, err := c.dial()
if err != nil { if err != nil {
@ -58,7 +61,7 @@ func TestClient_DialBackoff(t *testing.T) {
return nil, errors.New("foobar") return nil, errors.New("foobar")
} }
c := NewClient(&ClientConfig{ c, err := NewClient(&ClientConfig{
ServerAddr: "8.8.8.8", ServerAddr: "8.8.8.8",
TLSClientConfig: &tls.Config{}, TLSClientConfig: &tls.Config{},
DialTLS: d, DialTLS: d,
@ -66,12 +69,14 @@ func TestClient_DialBackoff(t *testing.T) {
Tunnels: map[string]*proto.Tunnel{"test": {}}, Tunnels: map[string]*proto.Tunnel{"test": {}},
Proxy: Proxy(ProxyFuncs{}), Proxy: Proxy(ProxyFuncs{}),
}) })
if err != nil {
t.Fatal(err)
}
start := time.Now() start := time.Now()
_, err := c.dial() _, err = c.dial()
end := time.Now()
if end.Sub(start) < 100*time.Millisecond { if time.Since(start) < 100*time.Millisecond {
t.Fatal("Wait mismatch", err) t.Fatal("Wait mismatch", err)
} }

View file

@ -90,11 +90,11 @@ func main() {
b, err := yaml.Marshal(config) b, err := yaml.Marshal(config)
if err != nil { if err != nil {
fatal("failed to load config: %s", err) fatal("failed to dump config: %s", err)
} }
logger.Log("config", string(b)) logger.Log("config", string(b))
client := tunnel.NewClient(&tunnel.ClientConfig{ client, err := tunnel.NewClient(&tunnel.ClientConfig{
ServerAddr: config.ServerAddr, ServerAddr: config.ServerAddr,
TLSClientConfig: tlsconf, TLSClientConfig: tlsconf,
Backoff: expBackoff(config.Backoff), Backoff: expBackoff(config.Backoff),
@ -102,9 +102,12 @@ func main() {
Proxy: proxy(config.Tunnels, logger), Proxy: proxy(config.Tunnels, logger),
Logger: logger, Logger: logger,
}) })
if err != nil {
fatal("failed to create client: %s", err)
}
if err := client.Start(); err != nil { if err := client.Start(); err != nil {
fatal("%s", err) fatal("failed to start tunnels: %s", err)
} }
} }

View file

@ -7,7 +7,6 @@ package tunnel
import ( import (
"bufio" "bufio"
"context" "context"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -38,10 +37,6 @@ type HTTPProxy struct {
// NewHTTPProxy creates a new direct HTTPProxy, everything will be proxied to // NewHTTPProxy creates a new direct HTTPProxy, everything will be proxied to
// localURL. // localURL.
func NewHTTPProxy(localURL *url.URL, logger log.Logger) *HTTPProxy { func NewHTTPProxy(localURL *url.URL, logger log.Logger) *HTTPProxy {
if localURL == nil {
panic("empty localURL")
}
if logger == nil { if logger == nil {
logger = log.NewNopLogger() logger = log.NewNopLogger()
} }
@ -58,10 +53,6 @@ func NewHTTPProxy(localURL *url.URL, logger log.Logger) *HTTPProxy {
// NewMultiHTTPProxy creates a new dispatching HTTPProxy, requests may go to // NewMultiHTTPProxy creates a new dispatching HTTPProxy, requests may go to
// different backends based on localURLMap. // different backends based on localURLMap.
func NewMultiHTTPProxy(localURLMap map[string]*url.URL, logger log.Logger) *HTTPProxy { func NewMultiHTTPProxy(localURLMap map[string]*url.URL, logger log.Logger) *HTTPProxy {
if localURLMap == nil {
panic("empty localURLMap")
}
if logger == nil { if logger == nil {
logger = log.NewNopLogger() logger = log.NewNopLogger()
} }
@ -91,7 +82,11 @@ func (p *HTTPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessag
rw, ok := w.(http.ResponseWriter) rw, ok := w.(http.ResponseWriter)
if !ok { if !ok {
panic(fmt.Sprintf("Expected http.ResponseWriter got %T", w)) p.logger.Log(
"level", 0,
"msg", "expected http.ResponseWriter",
"ctrlMsg", msg,
)
} }
req, err := http.ReadRequest(bufio.NewReader(r)) req, err := http.ReadRequest(bufio.NewReader(r))
@ -168,7 +163,7 @@ func singleJoiningSlash(a, b string) string {
} }
func (p *HTTPProxy) localURLFor(u *url.URL) *url.URL { func (p *HTTPProxy) localURLFor(u *url.URL) *url.URL {
if p.localURLMap == nil { if len(p.localURLMap) == 0 {
return p.localURL return p.localURL
} }

View file

@ -51,7 +51,7 @@ func echoHTTP(t testing.TB, l net.Listener) {
if r.Body != nil { if r.Body != nil {
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
w.Write(body) w.Write(body)
} }
@ -132,7 +132,7 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr,
} }
cert, _ := selfSignedCert() cert, _ := selfSignedCert()
c := tunnel.NewClient(&tunnel.ClientConfig{ c, err := tunnel.NewClient(&tunnel.ClientConfig{
ServerAddr: serverAddr, ServerAddr: serverAddr,
TLSClientConfig: tlsConfig(cert), TLSClientConfig: tlsConfig(cert),
Tunnels: tunnels, Tunnels: tunnels,
@ -142,7 +142,14 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr,
}), }),
Logger: log.NewStdLogger(), Logger: log.NewStdLogger(),
}) })
go c.Start() if err != nil {
t.Fatal(err)
}
go func() {
if err := c.Start(); err != nil {
t.Log(err)
}
}()
return c return c
} }

View file

@ -31,10 +31,6 @@ type TCPProxy struct {
// NewTCPProxy creates new direct TCPProxy, everything will be proxied to // NewTCPProxy creates new direct TCPProxy, everything will be proxied to
// localAddr. // localAddr.
func NewTCPProxy(localAddr string, logger log.Logger) *TCPProxy { func NewTCPProxy(localAddr string, logger log.Logger) *TCPProxy {
if localAddr == "" {
panic("missing localAddr")
}
if logger == nil { if logger == nil {
logger = log.NewNopLogger() logger = log.NewNopLogger()
} }
@ -48,10 +44,6 @@ func NewTCPProxy(localAddr string, logger log.Logger) *TCPProxy {
// NewMultiTCPProxy creates a new dispatching TCPProxy, connections may go to // NewMultiTCPProxy creates a new dispatching TCPProxy, connections may go to
// different backends based on localAddrMap. // different backends based on localAddrMap.
func NewMultiTCPProxy(localAddrMap map[string]string, logger log.Logger) *TCPProxy { func NewMultiTCPProxy(localAddrMap map[string]string, logger log.Logger) *TCPProxy {
if localAddrMap == nil {
panic("missing localAddrMap")
}
if logger == nil { if logger == nil {
logger = log.NewNopLogger() logger = log.NewNopLogger()
} }
@ -117,7 +109,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
} }
func (p *TCPProxy) localAddrFor(hostPort string) string { func (p *TCPProxy) localAddrFor(hostPort string) string {
if p.localAddrMap == nil { if len(p.localAddrMap) == 0 {
return p.localAddr return p.localAddr
} }