diff --git a/client.go b/client.go index ac86fc4..7d27274 100644 --- a/client.go +++ b/client.go @@ -7,11 +7,13 @@ import ( "net" "net/http" + "github.com/koding/h2tun/proto" "github.com/koding/logging" "golang.org/x/net/http2" ) type Client struct { + proxy ProxyFunc serverAddr string tlsConfig *tls.Config conn net.Conn @@ -19,8 +21,9 @@ type Client struct { log logging.Logger } -func NewClient(serverAddr string, tlsConfig *tls.Config) *Client { +func NewClient(proxy ProxyFunc, serverAddr string, tlsConfig *tls.Config) *Client { return &Client{ + proxy: proxy, serverAddr: serverAddr, tlsConfig: tlsConfig, httpServer: &http2.Server{}, @@ -36,12 +39,29 @@ func (c *Client) Connect() error { c.conn = conn c.httpServer.ServeConn(conn, &http2.ServeConnOpts{ - Handler: http.HandlerFunc(c.proxy), + Handler: http.HandlerFunc(c.serveHTTP), }) return nil } +func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + c.log.Info("Handshake: hello from server") + http.Error(w, "Nice to see you", http.StatusOK) + return + } + + msg, err := proto.ParseControlMessage(r.Header) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + c.log.Debug("Proxy init %s %v", r.RemoteAddr, msg) + c.proxy(flushWriter{w}, r.Body, msg) + c.log.Debug("Proxy over %s %v", r.RemoteAddr, msg) +} + type flushWriter struct { w io.Writer } @@ -54,15 +74,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) { return } -func (c *Client) proxy(w http.ResponseWriter, r *http.Request) { - c.log.Info("New proxy request") - if r.Method == http.MethodConnect { - http.Error(w, "OK", http.StatusOK) - } else { - io.Copy(flushWriter{w}, r.Body) - } -} - func (c *Client) Close() error { if c.conn == nil { return nil diff --git a/control_msg.go b/control_msg.go deleted file mode 100644 index 9cf26ee..0000000 --- a/control_msg.go +++ /dev/null @@ -1,26 +0,0 @@ -package h2tun - -// ControlMessage is sent from server to client to establish tunneled connection. -type ControlMessage struct { - Action Action - Protocol Type - LocalPort int -} - -// Action represents type of ControlMsg request. -type Action int - -// ControlMessage actions. -const ( - RequestClientSession Action = iota + 1 -) - -// Type represents tunneled connection type. -type Type int - -// ControlMessage protocols. -const ( - HTTP Type = iota + 1 - WS - RAW -) diff --git a/h2tun_test.go b/h2tun_test.go index 05655b7..3cd22a9 100644 --- a/h2tun_test.go +++ b/h2tun_test.go @@ -1,22 +1,29 @@ package h2tun_test import ( + "bufio" "bytes" "crypto/tls" "crypto/x509" + "fmt" "io" + "io/ioutil" "net" + "net/http" + "net/http/httptest" + "strings" "sync" "testing" "time" "github.com/andrew-d/id" "github.com/koding/h2tun" + "github.com/koding/h2tun/proto" "github.com/koding/logging" "github.com/stretchr/testify/assert" ) -func TestHTTP(t *testing.T) { +func TestTCP(t *testing.T) { logging.DefaultLevel = logging.DEBUG logging.DefaultHandler.SetLevel(logging.DEBUG) @@ -42,7 +49,7 @@ func TestHTTP(t *testing.T) { server.Start() defer server.Close() - client := h2tun.NewClient(server.Addr().String(), tlsConfig(cert)) + client := h2tun.NewClient(echoProxyFunc, server.Addr().String(), tlsConfig(cert)) go client.Connect() defer client.Close() @@ -73,6 +80,80 @@ func TestHTTP(t *testing.T) { wg.Wait() } +func echoProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) { + io.Copy(w, r) +} + +func TestHTTP(t *testing.T) { + logging.DefaultLevel = logging.DEBUG + logging.DefaultHandler.SetLevel(logging.DEBUG) + + cert, err := loadTestCert() + assert.Nil(t, err) + clientID := idFromTLSCert(cert) + + server, err := h2tun.NewServer( + tlsConfig(cert), + []*h2tun.AllowedClient{ + { + ID: clientID, + Host: "foobar.com", + }, + }, + ) + assert.Nil(t, err) + server.Start() + defer server.Close() + + client := h2tun.NewClient(echoHTTPProxyFunc, server.Addr().String(), tlsConfig(cert)) + go client.Connect() + defer client.Close() + + time.Sleep(time.Second) + + s := httptest.NewServer(server) + defer s.Close() + + const testPayload = "this is a test" + + _, port, _ := net.SplitHostPort(s.Listener.Addr().String()) + r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://foobar.com:%s/some/path", port), strings.NewReader(testPayload)) + assert.Nil(t, err) + + resp, err := http.DefaultClient.Do(r) + assert.Nil(t, err) + body, err := ioutil.ReadAll(resp.Body) + assert.Equal(t, testPayload, string(body)) +} + +func echoHTTPProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) { + req, err := http.ReadRequest(bufio.NewReader(r)) + if err != nil { + panic(err) + } + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + panic(err) + } + + headers := make(http.Header) + headers.Set("Content-Type", "text/plain") + + resp := &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: req, + Header: headers, + ContentLength: int64(len(body)), + Body: ioutil.NopCloser(bytes.NewReader(body)), + } + resp.Write(w) +} + func loadTestCert() (tls.Certificate, error) { return tls.LoadX509KeyPair("./test-fixtures/selfsigned.crt", "./test-fixtures/selfsigned.key") } diff --git a/proto/controlmsg.go b/proto/controlmsg.go new file mode 100644 index 0000000..d50f3ba --- /dev/null +++ b/proto/controlmsg.go @@ -0,0 +1,73 @@ +package proto + +import ( + "errors" + "fmt" + "net/http" + "regexp" +) + +const ( + HTTPProtocol = "HTTP" +) + +// Action represents type of ControlMsg request. +type Action int + +// ControlMessage actions. +const ( + RequestClientSession Action = iota +) + +// ControlMessage headers +const ( + ForwardedHeader = "Forwarded" +) + +// ControlMessage is sent from server to client to establish tunneled connection. +type ControlMessage struct { + Action Action + Protocol string + ForwardedFor string + ForwardedBy string + URLPath string +} + +var xffRegexp = regexp.MustCompile("(for|by|proto|path)=([^;$]+)") + +// NewControlMessage creates control message based on `Forwarded` http header. +func ParseControlMessage(h http.Header) (*ControlMessage, error) { + v := h.Get(ForwardedHeader) + if v == "" { + return nil, errors.New("missing Forwarded header") + } + + var msg = &ControlMessage{} + + for _, i := range xffRegexp.FindAllStringSubmatch(v, -1) { + switch i[1] { + case "for": + msg.ForwardedFor = i[2] + case "by": + msg.ForwardedBy = i[2] + case "proto": + msg.Protocol = i[2] + case "path": + msg.URLPath = i[2] + } + } + + return msg, nil +} + +// WriteTo writes ControlMessage to `Forwarded` http header, "by" and "for" parameters +// take form of full IP and port. +// +// If the server receiving proxied requests requires some address-based functionality, +// this parameter MAY instead contain an IP address (and, potentially, a port number) +// +// see https://tools.ietf.org/html/rfc7239. +func (c *ControlMessage) WriteTo(h http.Header) { + h.Set(ForwardedHeader, fmt.Sprintf("for=%s; by=%s; proto=%s; path=%s", + c.ForwardedFor, c.ForwardedBy, c.Protocol, c.URLPath)) +} diff --git a/proto/controlmsg_test.go b/proto/controlmsg_test.go new file mode 100644 index 0000000..7c70d2c --- /dev/null +++ b/proto/controlmsg_test.go @@ -0,0 +1,24 @@ +package proto + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestControlMessage_WriteParse(t *testing.T) { + msg := &ControlMessage{ + Protocol: "tcp", + ForwardedFor: "127.0.0.1:58104", + ForwardedBy: "127.0.0.1:7777", + URLPath: "/some/path", + } + + h := make(http.Header) + msg.WriteTo(h) + actual, err := ParseControlMessage(h) + + assert.Nil(t, err) + assert.Equal(t, msg, actual) +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..6124639 --- /dev/null +++ b/proxy.go @@ -0,0 +1,10 @@ +package h2tun + +import ( + "io" + + "github.com/koding/h2tun/proto" +) + +// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back. +type ProxyFunc func(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) diff --git a/server.go b/server.go index 5842909..ef43f31 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package h2tun import ( + "bufio" "crypto/tls" "fmt" "io" @@ -10,22 +11,26 @@ import ( "time" "github.com/andrew-d/id" + "github.com/koding/h2tun/proto" "github.com/koding/logging" "golang.org/x/net/http2" ) +// TODO mma introduce config object // TODO mma add ListenerFunc func(net.Listener) net.Listener to allow for tls listener decoration -// TODO mma add dynamic allowed clients modifications +// TODO mma document +// +// TODO (phase2) mma add dynamic allowed clients modifications + type Server struct { allowedClients []*AllowedClient - listener net.Listener + + listener net.Listener httpClient *http.Client hostConn map[string]net.Conn hostConnMu sync.RWMutex - tcpPorts map[int]*AllowedClient - log logging.Logger } @@ -100,19 +105,19 @@ func (s *Server) handleClient(conn net.Conn) { id, err := peerID(conn.(*tls.Conn)) if err != nil { s.log.Warning("Certificate error: %s", err) - goto cleanup + goto reject } client, ok = s.checkID(id) if !ok { s.log.Warning("Unknown certificate: %q", id.String()) - goto cleanup + goto reject } - req, err = http.NewRequest(http.MethodConnect, url(client, ""), nil) + req, err = http.NewRequest(http.MethodConnect, url(client.Host), nil) if err != nil { s.log.Error("Invalid host %q for client %q", client.Host, client.ID) - goto cleanup + goto reject } if err = conn.SetDeadline(time.Time{}); err != nil { @@ -122,28 +127,37 @@ func (s *Server) handleClient(conn net.Conn) { if err := s.addHostConn(client, conn); err != nil { s.log.Warning("Could not add host: %s", err) - goto cleanup + goto reject } resp, err = s.httpClient.Do(req) if err != nil { s.log.Warning("Handshake failed %s", err) - goto cleanup + goto reject } if resp.StatusCode != http.StatusOK { s.log.Warning("Handshake failed") - goto cleanup + goto reject } return -cleanup: +reject: conn.Close() if client != nil { s.deleteHostConn(client.Host) } } +func (s *Server) checkID(id id.ID) (*AllowedClient, bool) { + for _, c := range s.allowedClients { + if id.Equals(c.ID) { + return c, true + } + } + return nil, false +} + func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) error { key := hostPort(client.Host) @@ -166,19 +180,9 @@ func (s *Server) deleteHostConn(host string) { } func hostPort(host string) string { - // TODO mma add support for custom ports return fmt.Sprint(host, ":443") } -func (s *Server) checkID(id id.ID) (*AllowedClient, bool) { - for _, c := range s.allowedClients { - if id.Equals(c.ID) { - return c, true - } - } - return nil, false -} - func (s *Server) listenClientListeners() { for _, client := range s.allowedClients { if client.Listeners == nil { @@ -198,49 +202,96 @@ func (s *Server) listen(l net.Listener, client *AllowedClient) { s.log.Warning("Accept failed: %s", err) continue } - s.log.Debug("Accepted connection from %q", conn.RemoteAddr().String()) + s.log.Debug("Accepted connection from %q", conn.RemoteAddr()) - // TODO mma get Protocol from Network - // TODO mma get LocalIP from Addr - msg := &ControlMessage{ - Action: RequestClientSession, - Protocol: RAW, + msg := &proto.ControlMessage{ + Action: proto.RequestClientSession, + Protocol: l.Addr().Network(), + ForwardedFor: conn.RemoteAddr().String(), + ForwardedBy: conn.LocalAddr().String(), } - go s.proxy(conn, client, msg) + go s.proxy(client.Host, conn, conn, msg) } } -func (s *Server) proxy(conn net.Conn, client *AllowedClient, msg *ControlMessage) { - defer conn.Close() +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + msg := &proto.ControlMessage{ + Action: proto.RequestClientSession, + Protocol: proto.HTTPProtocol, + ForwardedFor: r.RemoteAddr, + ForwardedBy: r.Host, + URLPath: r.URL.Path, + } + + s.proxy(trimPort(r.Host), w, r, msg) +} + +func trimPort(hostPort string) (host string) { + host, _, _ = net.SplitHostPort(hostPort) + if host == "" { + return hostPort + } + return +} + +func (s *Server) proxy(host string, w io.Writer, r interface{}, msg *proto.ControlMessage) { + s.log.Debug("Proxy init %s %v", host, msg) + + defer func() { + if c, ok := r.(io.Closer); ok { + c.Close() + } + }() pr, pw := io.Pipe() defer pr.Close() defer pw.Close() - req, err := http.NewRequest(http.MethodPut, url(client, ""), pr) + req, err := http.NewRequest(http.MethodPut, url(host), pr) if err != nil { s.log.Error("Request creation failed: %s", err) return } + msg.WriteTo(req.Header) - // read from caller, write to tunnel client - var wg sync.WaitGroup - wg.Add(1) - go func() { - transfer("local to remote", pw, conn, s.log) - wg.Done() - }() + var localToRemoteDone = make(chan struct{}) - // read from tunnel client, write to caller - resp, err := s.httpClient.Do(req) - if err != nil { - s.log.Error("Proxing conn from %q to %q failed: %s", conn.RemoteAddr().String(), client.Host, err) - return + localToRemote := func() { + if hr, ok := r.(*http.Request); ok { + hr.Write(pw) + pw.Close() + } else { + transfer("local to remote", pw, r.(io.ReadCloser), s.log) + } + close(localToRemoteDone) } - transfer("remote to local", conn, resp.Body, s.log) - wg.Wait() + remoteToLocal := func() { + resp, err := s.httpClient.Do(req) + if err != nil { + s.log.Error("Proxing conn to client %q failed: %s", host, err) + return + } + if hw, ok := w.(http.ResponseWriter); ok { + pr, err := http.ReadResponse(bufio.NewReader(resp.Body), r.(*http.Request)) + if err != nil { + s.log.Error("Reading HTTP response failed: %s", err) + return + } + copyHeader(hw.Header(), pr.Header) + hw.WriteHeader(pr.StatusCode) + transfer("remote to local", hw, pr.Body, s.log) + } else { + transfer("remote to local", w, resp.Body, s.log) + } + } + + go localToRemote() + remoteToLocal() + <-localToRemoteDone + + s.log.Debug("Proxy over %s %v", host, msg) } func (s *Server) Addr() net.Addr { diff --git a/utils.go b/utils.go index 5a39972..a004695 100644 --- a/utils.go +++ b/utils.go @@ -3,12 +3,13 @@ package h2tun import ( "fmt" "io" + "net/http" "github.com/koding/logging" ) -func url(client *AllowedClient, path string) string { - return fmt.Sprint("https://", client.Host, path) +func url(host string) string { + return fmt.Sprint("https://", host) } type closeWriter interface { @@ -19,7 +20,7 @@ type closeReader interface { } func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) { - log.Debug("proxing") + log.Debug("proxing %s", side) n, err := io.Copy(dst, src) if err != nil { @@ -36,5 +37,13 @@ func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) src.Close() } - log.Debug("done proxing %d bytes", n) + log.Debug("done proxing %s %d bytes", side, n) +} + +func copyHeader(dst, src http.Header) { + for k, v := range src { + vv := make([]string, len(v)) + copy(vv, v) + dst[k] = vv + } }