http proxy: control message, improvements to proxing

This commit is contained in:
mmatczuk 2016-09-23 22:34:52 +02:00
parent dcb2f8ba09
commit f8a8ff0163
8 changed files with 322 additions and 89 deletions

View file

@ -7,11 +7,13 @@ import (
"net"
"net/http"
"github.com/koding/h2tun/proto"
"github.com/koding/logging"
"golang.org/x/net/http2"
)
type Client struct {
proxy ProxyFunc
serverAddr string
tlsConfig *tls.Config
conn net.Conn
@ -19,8 +21,9 @@ type Client struct {
log logging.Logger
}
func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
func NewClient(proxy ProxyFunc, serverAddr string, tlsConfig *tls.Config) *Client {
return &Client{
proxy: proxy,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
httpServer: &http2.Server{},
@ -36,12 +39,29 @@ func (c *Client) Connect() error {
c.conn = conn
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
Handler: http.HandlerFunc(c.proxy),
Handler: http.HandlerFunc(c.serveHTTP),
})
return nil
}
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
c.log.Info("Handshake: hello from server")
http.Error(w, "Nice to see you", http.StatusOK)
return
}
msg, err := proto.ParseControlMessage(r.Header)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
c.log.Debug("Proxy init %s %v", r.RemoteAddr, msg)
c.proxy(flushWriter{w}, r.Body, msg)
c.log.Debug("Proxy over %s %v", r.RemoteAddr, msg)
}
type flushWriter struct {
w io.Writer
}
@ -54,15 +74,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) {
return
}
func (c *Client) proxy(w http.ResponseWriter, r *http.Request) {
c.log.Info("New proxy request")
if r.Method == http.MethodConnect {
http.Error(w, "OK", http.StatusOK)
} else {
io.Copy(flushWriter{w}, r.Body)
}
}
func (c *Client) Close() error {
if c.conn == nil {
return nil

View file

@ -1,26 +0,0 @@
package h2tun
// ControlMessage is sent from server to client to establish tunneled connection.
type ControlMessage struct {
Action Action
Protocol Type
LocalPort int
}
// Action represents type of ControlMsg request.
type Action int
// ControlMessage actions.
const (
RequestClientSession Action = iota + 1
)
// Type represents tunneled connection type.
type Type int
// ControlMessage protocols.
const (
HTTP Type = iota + 1
WS
RAW
)

View file

@ -1,22 +1,29 @@
package h2tun_test
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/andrew-d/id"
"github.com/koding/h2tun"
"github.com/koding/h2tun/proto"
"github.com/koding/logging"
"github.com/stretchr/testify/assert"
)
func TestHTTP(t *testing.T) {
func TestTCP(t *testing.T) {
logging.DefaultLevel = logging.DEBUG
logging.DefaultHandler.SetLevel(logging.DEBUG)
@ -42,7 +49,7 @@ func TestHTTP(t *testing.T) {
server.Start()
defer server.Close()
client := h2tun.NewClient(server.Addr().String(), tlsConfig(cert))
client := h2tun.NewClient(echoProxyFunc, server.Addr().String(), tlsConfig(cert))
go client.Connect()
defer client.Close()
@ -73,6 +80,80 @@ func TestHTTP(t *testing.T) {
wg.Wait()
}
func echoProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
io.Copy(w, r)
}
func TestHTTP(t *testing.T) {
logging.DefaultLevel = logging.DEBUG
logging.DefaultHandler.SetLevel(logging.DEBUG)
cert, err := loadTestCert()
assert.Nil(t, err)
clientID := idFromTLSCert(cert)
server, err := h2tun.NewServer(
tlsConfig(cert),
[]*h2tun.AllowedClient{
{
ID: clientID,
Host: "foobar.com",
},
},
)
assert.Nil(t, err)
server.Start()
defer server.Close()
client := h2tun.NewClient(echoHTTPProxyFunc, server.Addr().String(), tlsConfig(cert))
go client.Connect()
defer client.Close()
time.Sleep(time.Second)
s := httptest.NewServer(server)
defer s.Close()
const testPayload = "this is a test"
_, port, _ := net.SplitHostPort(s.Listener.Addr().String())
r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://foobar.com:%s/some/path", port), strings.NewReader(testPayload))
assert.Nil(t, err)
resp, err := http.DefaultClient.Do(r)
assert.Nil(t, err)
body, err := ioutil.ReadAll(resp.Body)
assert.Equal(t, testPayload, string(body))
}
func echoHTTPProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
req, err := http.ReadRequest(bufio.NewReader(r))
if err != nil {
panic(err)
}
body, err := ioutil.ReadAll(req.Body)
if err != nil {
panic(err)
}
headers := make(http.Header)
headers.Set("Content-Type", "text/plain")
resp := &http.Response{
Status: "200 OK",
StatusCode: 200,
Proto: "HTTP/1.0",
ProtoMajor: 1,
ProtoMinor: 0,
Request: req,
Header: headers,
ContentLength: int64(len(body)),
Body: ioutil.NopCloser(bytes.NewReader(body)),
}
resp.Write(w)
}
func loadTestCert() (tls.Certificate, error) {
return tls.LoadX509KeyPair("./test-fixtures/selfsigned.crt", "./test-fixtures/selfsigned.key")
}

73
proto/controlmsg.go Normal file
View file

@ -0,0 +1,73 @@
package proto
import (
"errors"
"fmt"
"net/http"
"regexp"
)
const (
HTTPProtocol = "HTTP"
)
// Action represents type of ControlMsg request.
type Action int
// ControlMessage actions.
const (
RequestClientSession Action = iota
)
// ControlMessage headers
const (
ForwardedHeader = "Forwarded"
)
// ControlMessage is sent from server to client to establish tunneled connection.
type ControlMessage struct {
Action Action
Protocol string
ForwardedFor string
ForwardedBy string
URLPath string
}
var xffRegexp = regexp.MustCompile("(for|by|proto|path)=([^;$]+)")
// NewControlMessage creates control message based on `Forwarded` http header.
func ParseControlMessage(h http.Header) (*ControlMessage, error) {
v := h.Get(ForwardedHeader)
if v == "" {
return nil, errors.New("missing Forwarded header")
}
var msg = &ControlMessage{}
for _, i := range xffRegexp.FindAllStringSubmatch(v, -1) {
switch i[1] {
case "for":
msg.ForwardedFor = i[2]
case "by":
msg.ForwardedBy = i[2]
case "proto":
msg.Protocol = i[2]
case "path":
msg.URLPath = i[2]
}
}
return msg, nil
}
// WriteTo writes ControlMessage to `Forwarded` http header, "by" and "for" parameters
// take form of full IP and port.
//
// If the server receiving proxied requests requires some address-based functionality,
// this parameter MAY instead contain an IP address (and, potentially, a port number)
//
// see https://tools.ietf.org/html/rfc7239.
func (c *ControlMessage) WriteTo(h http.Header) {
h.Set(ForwardedHeader, fmt.Sprintf("for=%s; by=%s; proto=%s; path=%s",
c.ForwardedFor, c.ForwardedBy, c.Protocol, c.URLPath))
}

24
proto/controlmsg_test.go Normal file
View file

@ -0,0 +1,24 @@
package proto
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestControlMessage_WriteParse(t *testing.T) {
msg := &ControlMessage{
Protocol: "tcp",
ForwardedFor: "127.0.0.1:58104",
ForwardedBy: "127.0.0.1:7777",
URLPath: "/some/path",
}
h := make(http.Header)
msg.WriteTo(h)
actual, err := ParseControlMessage(h)
assert.Nil(t, err)
assert.Equal(t, msg, actual)
}

10
proxy.go Normal file
View file

@ -0,0 +1,10 @@
package h2tun
import (
"io"
"github.com/koding/h2tun/proto"
)
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
type ProxyFunc func(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage)

143
server.go
View file

@ -1,6 +1,7 @@
package h2tun
import (
"bufio"
"crypto/tls"
"fmt"
"io"
@ -10,22 +11,26 @@ import (
"time"
"github.com/andrew-d/id"
"github.com/koding/h2tun/proto"
"github.com/koding/logging"
"golang.org/x/net/http2"
)
// TODO mma introduce config object
// TODO mma add ListenerFunc func(net.Listener) net.Listener to allow for tls listener decoration
// TODO mma add dynamic allowed clients modifications
// TODO mma document
//
// TODO (phase2) mma add dynamic allowed clients modifications
type Server struct {
allowedClients []*AllowedClient
listener net.Listener
listener net.Listener
httpClient *http.Client
hostConn map[string]net.Conn
hostConnMu sync.RWMutex
tcpPorts map[int]*AllowedClient
log logging.Logger
}
@ -100,19 +105,19 @@ func (s *Server) handleClient(conn net.Conn) {
id, err := peerID(conn.(*tls.Conn))
if err != nil {
s.log.Warning("Certificate error: %s", err)
goto cleanup
goto reject
}
client, ok = s.checkID(id)
if !ok {
s.log.Warning("Unknown certificate: %q", id.String())
goto cleanup
goto reject
}
req, err = http.NewRequest(http.MethodConnect, url(client, ""), nil)
req, err = http.NewRequest(http.MethodConnect, url(client.Host), nil)
if err != nil {
s.log.Error("Invalid host %q for client %q", client.Host, client.ID)
goto cleanup
goto reject
}
if err = conn.SetDeadline(time.Time{}); err != nil {
@ -122,28 +127,37 @@ func (s *Server) handleClient(conn net.Conn) {
if err := s.addHostConn(client, conn); err != nil {
s.log.Warning("Could not add host: %s", err)
goto cleanup
goto reject
}
resp, err = s.httpClient.Do(req)
if err != nil {
s.log.Warning("Handshake failed %s", err)
goto cleanup
goto reject
}
if resp.StatusCode != http.StatusOK {
s.log.Warning("Handshake failed")
goto cleanup
goto reject
}
return
cleanup:
reject:
conn.Close()
if client != nil {
s.deleteHostConn(client.Host)
}
}
func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
for _, c := range s.allowedClients {
if id.Equals(c.ID) {
return c, true
}
}
return nil, false
}
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) error {
key := hostPort(client.Host)
@ -166,19 +180,9 @@ func (s *Server) deleteHostConn(host string) {
}
func hostPort(host string) string {
// TODO mma add support for custom ports
return fmt.Sprint(host, ":443")
}
func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
for _, c := range s.allowedClients {
if id.Equals(c.ID) {
return c, true
}
}
return nil, false
}
func (s *Server) listenClientListeners() {
for _, client := range s.allowedClients {
if client.Listeners == nil {
@ -198,49 +202,96 @@ func (s *Server) listen(l net.Listener, client *AllowedClient) {
s.log.Warning("Accept failed: %s", err)
continue
}
s.log.Debug("Accepted connection from %q", conn.RemoteAddr().String())
s.log.Debug("Accepted connection from %q", conn.RemoteAddr())
// TODO mma get Protocol from Network
// TODO mma get LocalIP from Addr
msg := &ControlMessage{
Action: RequestClientSession,
Protocol: RAW,
msg := &proto.ControlMessage{
Action: proto.RequestClientSession,
Protocol: l.Addr().Network(),
ForwardedFor: conn.RemoteAddr().String(),
ForwardedBy: conn.LocalAddr().String(),
}
go s.proxy(conn, client, msg)
go s.proxy(client.Host, conn, conn, msg)
}
}
func (s *Server) proxy(conn net.Conn, client *AllowedClient, msg *ControlMessage) {
defer conn.Close()
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
msg := &proto.ControlMessage{
Action: proto.RequestClientSession,
Protocol: proto.HTTPProtocol,
ForwardedFor: r.RemoteAddr,
ForwardedBy: r.Host,
URLPath: r.URL.Path,
}
s.proxy(trimPort(r.Host), w, r, msg)
}
func trimPort(hostPort string) (host string) {
host, _, _ = net.SplitHostPort(hostPort)
if host == "" {
return hostPort
}
return
}
func (s *Server) proxy(host string, w io.Writer, r interface{}, msg *proto.ControlMessage) {
s.log.Debug("Proxy init %s %v", host, msg)
defer func() {
if c, ok := r.(io.Closer); ok {
c.Close()
}
}()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
req, err := http.NewRequest(http.MethodPut, url(client, ""), pr)
req, err := http.NewRequest(http.MethodPut, url(host), pr)
if err != nil {
s.log.Error("Request creation failed: %s", err)
return
}
msg.WriteTo(req.Header)
// read from caller, write to tunnel client
var wg sync.WaitGroup
wg.Add(1)
go func() {
transfer("local to remote", pw, conn, s.log)
wg.Done()
}()
var localToRemoteDone = make(chan struct{})
// read from tunnel client, write to caller
resp, err := s.httpClient.Do(req)
if err != nil {
s.log.Error("Proxing conn from %q to %q failed: %s", conn.RemoteAddr().String(), client.Host, err)
return
localToRemote := func() {
if hr, ok := r.(*http.Request); ok {
hr.Write(pw)
pw.Close()
} else {
transfer("local to remote", pw, r.(io.ReadCloser), s.log)
}
close(localToRemoteDone)
}
transfer("remote to local", conn, resp.Body, s.log)
wg.Wait()
remoteToLocal := func() {
resp, err := s.httpClient.Do(req)
if err != nil {
s.log.Error("Proxing conn to client %q failed: %s", host, err)
return
}
if hw, ok := w.(http.ResponseWriter); ok {
pr, err := http.ReadResponse(bufio.NewReader(resp.Body), r.(*http.Request))
if err != nil {
s.log.Error("Reading HTTP response failed: %s", err)
return
}
copyHeader(hw.Header(), pr.Header)
hw.WriteHeader(pr.StatusCode)
transfer("remote to local", hw, pr.Body, s.log)
} else {
transfer("remote to local", w, resp.Body, s.log)
}
}
go localToRemote()
remoteToLocal()
<-localToRemoteDone
s.log.Debug("Proxy over %s %v", host, msg)
}
func (s *Server) Addr() net.Addr {

View file

@ -3,12 +3,13 @@ package h2tun
import (
"fmt"
"io"
"net/http"
"github.com/koding/logging"
)
func url(client *AllowedClient, path string) string {
return fmt.Sprint("https://", client.Host, path)
func url(host string) string {
return fmt.Sprint("https://", host)
}
type closeWriter interface {
@ -19,7 +20,7 @@ type closeReader interface {
}
func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) {
log.Debug("proxing")
log.Debug("proxing %s", side)
n, err := io.Copy(dst, src)
if err != nil {
@ -36,5 +37,13 @@ func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger)
src.Close()
}
log.Debug("done proxing %d bytes", n)
log.Debug("done proxing %s %d bytes", side, n)
}
func copyHeader(dst, src http.Header) {
for k, v := range src {
vv := make([]string, len(v))
copy(vv, v)
dst[k] = vv
}
}