websocket: experiment

This commit is contained in:
mmatczuk 2017-02-27 10:40:10 +01:00
parent 5c54b16fd1
commit 0dfac07ed4
10 changed files with 267 additions and 43 deletions

View file

@ -73,10 +73,6 @@ func NewMultiHTTPProxy(localURLMap map[string]*url.URL, logger log.Logger) *HTTP
// Proxy is a ProxyFunc.
func (p *HTTPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
if msg.Protocol != proto.HTTP {
panic(fmt.Sprintf("Expected proxy protocol, got %s", msg.Protocol))
}
rw, ok := w.(http.ResponseWriter)
if !ok {
panic(fmt.Sprintf("Expected http.ResponseWriter got %T", w))

View file

@ -12,6 +12,8 @@ import (
"testing"
"time"
"golang.org/x/net/websocket"
"github.com/mmatczuk/go-http-tunnel"
"github.com/mmatczuk/go-http-tunnel/id"
"github.com/mmatczuk/go-http-tunnel/log"
@ -36,9 +38,9 @@ type testContext struct {
var ctx testContext
func TestMain(m *testing.M) {
logger := log.NewFilterLogger(log.NewStdLogger(), 1)
logger := log.NewFilterLogger(log.NewStdLogger(), 2)
// prepare tunnel server
// server
cert, identifier := selfSignedCert()
s, err := tunnel.NewServer(&tunnel.ServerConfig{
Addr: ":0",
@ -52,25 +54,19 @@ func TestMain(m *testing.M) {
go s.Start()
defer s.Stop()
// run server HTTP interface
// server: expose HTTP
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer l.Close()
go http.Serve(l, s)
httpAddr := l.Addr()
// prepare local TCP echo service
echoTCPListener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer echoTCPListener.Close()
go EchoTCP(echoTCPListener)
// server: expose TCP
tcpAddr := freeAddr()
// prepare local HTTP echo service
// echo: HTTP
echoHTTPListener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
@ -78,10 +74,15 @@ func TestMain(m *testing.M) {
defer echoHTTPListener.Close()
go EchoHTTP(echoHTTPListener)
// allocate free port
tcpAddr := freeAddr()
// echo: TCP
echoTCPListener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer echoTCPListener.Close()
go EchoTCP(echoTCPListener)
// prepare tunnel client
// client: tunnels
tunnels := map[string]*proto.Tunnel{
"http": {
Protocol: proto.HTTP,
@ -94,6 +95,7 @@ func TestMain(m *testing.M) {
},
}
// client: proxy HTTP
httpProxy := tunnel.NewMultiHTTPProxy(map[string]*url.URL{
"localhost:" + port(httpAddr): {
Scheme: "http",
@ -101,16 +103,25 @@ func TestMain(m *testing.M) {
},
}, log.NewContext(logger).WithPrefix("HTTP proxy", ":"))
// client: proxy WS
wsProxy := tunnel.NewWSProxy(
&url.URL{Scheme: "ws", Host: "127.0.0.1:" + port(echoHTTPListener.Addr())},
log.NewContext(logger).WithPrefix("WS proxy", ":"),
)
// client: proxy WS
tcpProxy := tunnel.NewMultiTCPProxy(map[string]string{
port(tcpAddr): echoTCPListener.Addr().String(),
}, log.NewContext(logger).WithPrefix("TCP proxy", ":"))
// client
c := tunnel.NewClient(&tunnel.ClientConfig{
ServerAddr: s.Addr(),
TLSClientConfig: TLSConfig(cert),
Tunnels: tunnels,
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
HTTP: httpProxy.Proxy,
WS: wsProxy.Proxy,
TCP: tcpProxy.Proxy,
}),
Logger: log.NewContext(logger).WithPrefix("client", ":"),
@ -120,6 +131,7 @@ func TestMain(m *testing.M) {
time.Sleep(500 * time.Millisecond)
defer c.Stop()
// test context
ctx.httpAddr = httpAddr
ctx.tcpAddr = tcpAddr
ctx.payload = randPayload(payloadInitialSize, payloadLen)
@ -133,12 +145,13 @@ func TestProxying(t *testing.T) {
name string
seq []uint
}{
{"http", "small", []uint{200, 160, 120, 80, 40, 20}},
{"http", "mid", []uint{40, 80, 120, 160, 200}},
{"http", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 200}},
{"tcp", "small", []uint{200, 160, 120, 80, 40, 20}},
{"tcp", "mid", []uint{40, 80, 120, 160, 200}},
{"tcp", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 200}},
{"ws", "small", []uint{200, 160, 120, 80, 40, 20}},
//{"http", "small", []uint{200, 160, 120, 80, 40, 20}},
//{"http", "mid", []uint{40, 80, 120, 160, 200}},
//{"http", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 200}},
//{"tcp", "small", []uint{200, 160, 120, 80, 40, 20}},
//{"tcp", "mid", []uint{40, 80, 120, 160, 200}},
//{"tcp", "big", []uint{0, 0, 0, 0, 0, 0, 0, 0, 0, 200}},
}
for _, tt := range data {
@ -150,8 +163,29 @@ func TestProxying(t *testing.T) {
switch tt.protocol {
case "http":
testHTTP(t, ctx.payload[idx], repeat)
case "ws":
config, err := websocket.NewConfig(
fmt.Sprintf("ws://localhost:%s/some/path", port(ctx.httpAddr)),
"http://localhost/",
)
if err != nil {
panic("Invalid config")
}
config.Header.Set("Authorization", "Basic dXNlcjpwYXNzd29yZA==")
ws, err := websocket.DialConfig(config)
if err != nil {
t.Fatal("Dial failed", err)
}
testConn(t, ws, ctx.payload[idx], repeat)
ws.Close()
case "tcp":
testTCP(t, ctx.payload[idx], repeat)
conn, err := net.Dial("tcp", ctx.tcpAddr.String())
if err != nil {
t.Fatal("Dial failed", err)
}
testConn(t, conn, ctx.payload[idx], repeat)
conn.Close()
default:
panic("Unexpected network type")
}
@ -167,7 +201,7 @@ func testHTTP(t *testing.T, payload []byte, repeat uint) {
if err != nil {
panic("Failed to create request")
}
r.SetBasicAuth("user", "password")
r.Header.Set("Authorization", "Basic dXNlcjpwYXNzd29yZA==")
resp, err := http.DefaultClient.Do(r)
if err != nil {
@ -193,13 +227,7 @@ func testHTTP(t *testing.T, payload []byte, repeat uint) {
}
}
func testTCP(t *testing.T, payload []byte, repeat uint) {
conn, err := net.Dial("tcp", ctx.tcpAddr.String())
if err != nil {
t.Fatal("Dial failed", err)
}
defer conn.Close()
func testConn(t *testing.T, conn net.Conn, payload []byte, repeat uint) {
var buf = bigBuffer()
var read, write int
for repeat > 0 {

View file

@ -7,12 +7,24 @@ import (
"math/rand"
"net"
"net/http"
"strings"
"golang.org/x/net/websocket"
)
// EchoHTTP starts serving HTTP requests on listener l, it accepts connections,
// reads request body and writes is back in response.
func EchoHTTP(l net.Listener) {
wsServer := &websocket.Server{Handler: func(ws *websocket.Conn) {
io.Copy(ws, ws)
}}
http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isWebSocketConn(r) {
wsServer.ServeHTTP(w, r)
return
}
w.WriteHeader(http.StatusOK)
if r.Body != nil {
body, err := ioutil.ReadAll(r.Body)
@ -24,6 +36,23 @@ func EchoHTTP(l net.Listener) {
}))
}
func isWebSocketConn(r *http.Request) bool {
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
headerContains(r.Header["Upgrade"], "websocket")
}
func headerContains(header []string, value string) bool {
for _, h := range header {
for _, v := range strings.Split(h, ",") {
if strings.EqualFold(strings.TrimSpace(v), value) {
return true
}
}
}
return false
}
// EchoTCP accepts connections and copies back received bytes.
func EchoTCP(l net.Listener) {
for {

View file

@ -30,6 +30,7 @@ const (
TCP4 = "tcp4"
TCP6 = "tcp6"
UNIX = "unix"
WS = "ws"
)
// ControlMessage is sent from server to client to establish tunneled
@ -39,9 +40,10 @@ type ControlMessage struct {
Protocol string
ForwardedFor string
ForwardedBy string
Path string
}
var xffRegexp = regexp.MustCompile("(for|proto|by)=([^;$]+)")
var xffRegexp = regexp.MustCompile("(proto|for|by|path)=([^;$]+)")
// ParseControlMessage creates new ControlMessage based on "Forwarded" http
// header.
@ -55,12 +57,14 @@ func ParseControlMessage(h http.Header) (*ControlMessage, error) {
for _, i := range xffRegexp.FindAllStringSubmatch(v, -1) {
switch i[1] {
case "proto":
msg.Protocol = i[2]
case "for":
msg.ForwardedFor = i[2]
case "by":
msg.ForwardedBy = i[2]
case "proto":
msg.Protocol = i[2]
case "path":
msg.Path = i[2]
}
}
@ -72,6 +76,6 @@ func ParseControlMessage(h http.Header) (*ControlMessage, error) {
//
// See Forwarded header specification https://tools.ietf.org/html/rfc7239.
func (c *ControlMessage) Update(h http.Header) {
v := fmt.Sprintf("for=%s; proto=%s; by=%s", c.ForwardedFor, c.Protocol, c.ForwardedBy)
v := fmt.Sprintf("proto=%s; for=%s; by=%s; path=%s", c.Protocol, c.ForwardedFor, c.ForwardedBy, c.Path)
h.Set(ForwardedHeader, v)
}

View file

@ -13,6 +13,7 @@ func TestControlMessage_WriteParse(t *testing.T) {
Protocol: "tcp",
ForwardedFor: "127.0.0.1:58104",
ForwardedBy: "127.0.0.1:7777",
Path: "/some/path",
}
var h = http.Header{}

View file

@ -14,6 +14,8 @@ type ProxyFunc func(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage)
type ProxyFuncs struct {
// HTTP is custom implementation of HTTP proxing.
HTTP ProxyFunc
// WS is custom implementation of WS proxing.
WS ProxyFunc
// TCP is custom implementation of TCP proxing.
TCP ProxyFunc
}
@ -25,6 +27,8 @@ func Proxy(p ProxyFuncs) ProxyFunc {
switch msg.Protocol {
case proto.HTTP:
f = p.HTTP
case proto.WS:
f = p.WS
case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX:
f = p.TCP
}

View file

@ -12,6 +12,7 @@ import (
"time"
"golang.org/x/net/http2"
"golang.org/x/net/websocket"
"github.com/mmatczuk/go-http-tunnel/id"
"github.com/mmatczuk/go-http-tunnel/log"
@ -38,6 +39,7 @@ type Server struct {
*registry
config *ServerConfig
listener net.Listener
wsServer *websocket.Server
connPool *connPool
httpClient *http.Client
logger log.Logger
@ -62,6 +64,10 @@ func NewServer(config *ServerConfig) (*Server, error) {
logger: logger,
}
s.wsServer = &websocket.Server{
Handler: s.ServeWS,
}
t := &http2.Transport{}
pool := newConnPool(t, s.disconnected)
t.ConnPool = pool
@ -487,8 +493,14 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
<-done
}
// ServeHTTP proxies http connection to the client.
// ServeHTTP proxies HTTP connection to the client.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// handle websockets
if isWebSocketConn(r) {
s.wsServer.ServeHTTP(w, r)
return
}
resp, err := s.RoundTrip(r)
if err == errUnauthorised {
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
@ -526,6 +538,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
Protocol: proto.HTTP,
ForwardedFor: r.RemoteAddr,
ForwardedBy: r.Host,
Path: r.URL.Path,
}
identifier, auth, ok := s.Subscriber(r.Host)
@ -604,6 +617,47 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
return resp, nil
}
// ServeWS is WebSocket handler.
func (s *Server) ServeWS(ws *websocket.Conn) {
r := ws.Request()
var err error
identifier, auth, ok := s.Subscriber(r.Host)
if !ok {
err = errClientNotSubscribed
}
if auth != nil {
user, password, _ := r.BasicAuth()
if auth.User != user || auth.Password != password {
err = errUnauthorised
}
r.Header.Del("Authorization")
}
if err != nil {
s.logger.Log(
"level", 0,
"action", "round trip failed",
"addr", r.RemoteAddr,
"url", r.URL,
"err", err,
)
ws.WriteClose(http.StatusBadGateway)
return
}
msg := &proto.ControlMessage{
Action: proto.Proxy,
Protocol: proto.WS,
ForwardedFor: r.RemoteAddr,
ForwardedBy: r.Host,
Path: r.URL.Path,
}
s.proxyConn(identifier, ws, msg)
}
func (s *Server) proxyRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) {
if msg.Action != proto.Proxy {
panic("Invalid action")

View file

@ -62,10 +62,6 @@ func NewMultiTCPProxy(localAddrMap map[string]string, logger log.Logger) *TCPPro
func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
w = flushWriter{w}
if msg.Protocol != "tcp" {
panic(fmt.Sprintf("Expected proxy protocol, got %s", msg.Protocol))
}
target := p.localAddrFor(msg.ForwardedBy)
if target == "" {
p.logger.Log(
@ -95,10 +91,12 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
))
close(done)
}()
transfer(local, r, log.NewContext(p.logger).With(
"dst", target,
"src", msg.ForwardedBy,
))
<-done
}

View file

@ -3,6 +3,7 @@ package tunnel
import (
"io"
"net/http"
"strings"
"github.com/mmatczuk/go-http-tunnel/log"
)
@ -42,6 +43,23 @@ func transfer(dst io.Writer, src io.ReadCloser, logger log.Logger) {
)
}
func isWebSocketConn(r *http.Request) bool {
return r.Method == "GET" && headerContains(r.Header["Connection"], "upgrade") &&
headerContains(r.Header["Upgrade"], "websocket")
}
func headerContains(header []string, value string) bool {
for _, h := range header {
for _, v := range strings.Split(h, ",") {
if strings.EqualFold(strings.TrimSpace(v), value) {
return true
}
}
}
return false
}
func copyHeader(dst, src http.Header) {
for k, v := range src {
vv := make([]string, len(v))

92
wsproxy.go Normal file
View file

@ -0,0 +1,92 @@
package tunnel
import (
"fmt"
"io"
"net/url"
"golang.org/x/net/websocket"
"github.com/mmatczuk/go-http-tunnel/log"
"github.com/mmatczuk/go-http-tunnel/proto"
)
// WSProxy forwards HTTP traffic.
type WSProxy struct {
// localURL specifies default base URL of local service.
localURL *url.URL
// logger is the proxy logger.
logger log.Logger
}
// NewWSProxy creates a new direct WSProxy, everything will be proxied to
// localURL.
func NewWSProxy(localURL *url.URL, logger log.Logger) *WSProxy {
if localURL == nil {
panic("Empty localURL")
}
if logger == nil {
logger = log.NewNopLogger()
}
p := &WSProxy{
localURL: localURL,
logger: logger,
}
return p
}
// Proxy is a ProxyFunc.
func (p *WSProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
w = flushWriter{w}
target := p.localURL
if target == nil {
p.logger.Log(
"level", 1,
"msg", "no target",
"host", msg.ForwardedBy,
)
return
}
// TODO support path target.Path = singleJoiningSlash(target.Path, msg.Path)
config, err := websocket.NewConfig(target.String(), fmt.Sprintf("http://%s/", msg.ForwardedBy))
if err != nil {
p.logger.Log(
"level", 0,
"msg", "failed to create ws config",
"err", err,
)
return
}
// TODO support config.Header
ws, err := websocket.DialConfig(config)
if err != nil {
p.logger.Log(
"level", 0,
"msg", "ws dial failed",
"err", err,
)
return
}
done := make(chan struct{})
go func() {
transfer(w, ws, log.NewContext(p.logger).With(
"dst", msg.ForwardedBy,
"src", target,
))
close(done)
}()
transfer(ws, r, log.NewContext(p.logger).With(
"dst", target,
"src", msg.ForwardedBy,
))
<-done
}