mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
http proxy: control message, improvements to proxing
This commit is contained in:
parent
dcb2f8ba09
commit
f8a8ff0163
8 changed files with 322 additions and 89 deletions
33
client.go
33
client.go
|
|
@ -7,11 +7,13 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/koding/h2tun/proto"
|
||||
"github.com/koding/logging"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
proxy ProxyFunc
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
conn net.Conn
|
||||
|
|
@ -19,8 +21,9 @@ type Client struct {
|
|||
log logging.Logger
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
|
||||
func NewClient(proxy ProxyFunc, serverAddr string, tlsConfig *tls.Config) *Client {
|
||||
return &Client{
|
||||
proxy: proxy,
|
||||
serverAddr: serverAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
httpServer: &http2.Server{},
|
||||
|
|
@ -36,12 +39,29 @@ func (c *Client) Connect() error {
|
|||
c.conn = conn
|
||||
|
||||
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: http.HandlerFunc(c.proxy),
|
||||
Handler: http.HandlerFunc(c.serveHTTP),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodConnect {
|
||||
c.log.Info("Handshake: hello from server")
|
||||
http.Error(w, "Nice to see you", http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := proto.ParseControlMessage(r.Header)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.log.Debug("Proxy init %s %v", r.RemoteAddr, msg)
|
||||
c.proxy(flushWriter{w}, r.Body, msg)
|
||||
c.log.Debug("Proxy over %s %v", r.RemoteAddr, msg)
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
|
@ -54,15 +74,6 @@ func (fw flushWriter) Write(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Client) proxy(w http.ResponseWriter, r *http.Request) {
|
||||
c.log.Info("New proxy request")
|
||||
if r.Method == http.MethodConnect {
|
||||
http.Error(w, "OK", http.StatusOK)
|
||||
} else {
|
||||
io.Copy(flushWriter{w}, r.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
package h2tun
|
||||
|
||||
// ControlMessage is sent from server to client to establish tunneled connection.
|
||||
type ControlMessage struct {
|
||||
Action Action
|
||||
Protocol Type
|
||||
LocalPort int
|
||||
}
|
||||
|
||||
// Action represents type of ControlMsg request.
|
||||
type Action int
|
||||
|
||||
// ControlMessage actions.
|
||||
const (
|
||||
RequestClientSession Action = iota + 1
|
||||
)
|
||||
|
||||
// Type represents tunneled connection type.
|
||||
type Type int
|
||||
|
||||
// ControlMessage protocols.
|
||||
const (
|
||||
HTTP Type = iota + 1
|
||||
WS
|
||||
RAW
|
||||
)
|
||||
|
|
@ -1,22 +1,29 @@
|
|||
package h2tun_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/andrew-d/id"
|
||||
"github.com/koding/h2tun"
|
||||
"github.com/koding/h2tun/proto"
|
||||
"github.com/koding/logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
func TestTCP(t *testing.T) {
|
||||
logging.DefaultLevel = logging.DEBUG
|
||||
logging.DefaultHandler.SetLevel(logging.DEBUG)
|
||||
|
||||
|
|
@ -42,7 +49,7 @@ func TestHTTP(t *testing.T) {
|
|||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := h2tun.NewClient(server.Addr().String(), tlsConfig(cert))
|
||||
client := h2tun.NewClient(echoProxyFunc, server.Addr().String(), tlsConfig(cert))
|
||||
go client.Connect()
|
||||
defer client.Close()
|
||||
|
||||
|
|
@ -73,6 +80,80 @@ func TestHTTP(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func echoProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
|
||||
io.Copy(w, r)
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
logging.DefaultLevel = logging.DEBUG
|
||||
logging.DefaultHandler.SetLevel(logging.DEBUG)
|
||||
|
||||
cert, err := loadTestCert()
|
||||
assert.Nil(t, err)
|
||||
clientID := idFromTLSCert(cert)
|
||||
|
||||
server, err := h2tun.NewServer(
|
||||
tlsConfig(cert),
|
||||
[]*h2tun.AllowedClient{
|
||||
{
|
||||
ID: clientID,
|
||||
Host: "foobar.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := h2tun.NewClient(echoHTTPProxyFunc, server.Addr().String(), tlsConfig(cert))
|
||||
go client.Connect()
|
||||
defer client.Close()
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
s := httptest.NewServer(server)
|
||||
defer s.Close()
|
||||
|
||||
const testPayload = "this is a test"
|
||||
|
||||
_, port, _ := net.SplitHostPort(s.Listener.Addr().String())
|
||||
r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://foobar.com:%s/some/path", port), strings.NewReader(testPayload))
|
||||
assert.Nil(t, err)
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
assert.Nil(t, err)
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.Equal(t, testPayload, string(body))
|
||||
}
|
||||
|
||||
func echoHTTPProxyFunc(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage) {
|
||||
req, err := http.ReadRequest(bufio.NewReader(r))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "text/plain")
|
||||
|
||||
resp := &http.Response{
|
||||
Status: "200 OK",
|
||||
StatusCode: 200,
|
||||
Proto: "HTTP/1.0",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 0,
|
||||
Request: req,
|
||||
Header: headers,
|
||||
ContentLength: int64(len(body)),
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
resp.Write(w)
|
||||
}
|
||||
|
||||
func loadTestCert() (tls.Certificate, error) {
|
||||
return tls.LoadX509KeyPair("./test-fixtures/selfsigned.crt", "./test-fixtures/selfsigned.key")
|
||||
}
|
||||
|
|
|
|||
73
proto/controlmsg.go
Normal file
73
proto/controlmsg.go
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const (
|
||||
HTTPProtocol = "HTTP"
|
||||
)
|
||||
|
||||
// Action represents type of ControlMsg request.
|
||||
type Action int
|
||||
|
||||
// ControlMessage actions.
|
||||
const (
|
||||
RequestClientSession Action = iota
|
||||
)
|
||||
|
||||
// ControlMessage headers
|
||||
const (
|
||||
ForwardedHeader = "Forwarded"
|
||||
)
|
||||
|
||||
// ControlMessage is sent from server to client to establish tunneled connection.
|
||||
type ControlMessage struct {
|
||||
Action Action
|
||||
Protocol string
|
||||
ForwardedFor string
|
||||
ForwardedBy string
|
||||
URLPath string
|
||||
}
|
||||
|
||||
var xffRegexp = regexp.MustCompile("(for|by|proto|path)=([^;$]+)")
|
||||
|
||||
// NewControlMessage creates control message based on `Forwarded` http header.
|
||||
func ParseControlMessage(h http.Header) (*ControlMessage, error) {
|
||||
v := h.Get(ForwardedHeader)
|
||||
if v == "" {
|
||||
return nil, errors.New("missing Forwarded header")
|
||||
}
|
||||
|
||||
var msg = &ControlMessage{}
|
||||
|
||||
for _, i := range xffRegexp.FindAllStringSubmatch(v, -1) {
|
||||
switch i[1] {
|
||||
case "for":
|
||||
msg.ForwardedFor = i[2]
|
||||
case "by":
|
||||
msg.ForwardedBy = i[2]
|
||||
case "proto":
|
||||
msg.Protocol = i[2]
|
||||
case "path":
|
||||
msg.URLPath = i[2]
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// WriteTo writes ControlMessage to `Forwarded` http header, "by" and "for" parameters
|
||||
// take form of full IP and port.
|
||||
//
|
||||
// If the server receiving proxied requests requires some address-based functionality,
|
||||
// this parameter MAY instead contain an IP address (and, potentially, a port number)
|
||||
//
|
||||
// see https://tools.ietf.org/html/rfc7239.
|
||||
func (c *ControlMessage) WriteTo(h http.Header) {
|
||||
h.Set(ForwardedHeader, fmt.Sprintf("for=%s; by=%s; proto=%s; path=%s",
|
||||
c.ForwardedFor, c.ForwardedBy, c.Protocol, c.URLPath))
|
||||
}
|
||||
24
proto/controlmsg_test.go
Normal file
24
proto/controlmsg_test.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestControlMessage_WriteParse(t *testing.T) {
|
||||
msg := &ControlMessage{
|
||||
Protocol: "tcp",
|
||||
ForwardedFor: "127.0.0.1:58104",
|
||||
ForwardedBy: "127.0.0.1:7777",
|
||||
URLPath: "/some/path",
|
||||
}
|
||||
|
||||
h := make(http.Header)
|
||||
msg.WriteTo(h)
|
||||
actual, err := ParseControlMessage(h)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, msg, actual)
|
||||
}
|
||||
10
proxy.go
Normal file
10
proxy.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package h2tun
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/koding/h2tun/proto"
|
||||
)
|
||||
|
||||
// ProxyFunc is responsible for forwarding a remote connection to local server and writing the response back.
|
||||
type ProxyFunc func(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage)
|
||||
143
server.go
143
server.go
|
|
@ -1,6 +1,7 @@
|
|||
package h2tun
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -10,22 +11,26 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/andrew-d/id"
|
||||
"github.com/koding/h2tun/proto"
|
||||
"github.com/koding/logging"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// TODO mma introduce config object
|
||||
// TODO mma add ListenerFunc func(net.Listener) net.Listener to allow for tls listener decoration
|
||||
// TODO mma add dynamic allowed clients modifications
|
||||
// TODO mma document
|
||||
//
|
||||
// TODO (phase2) mma add dynamic allowed clients modifications
|
||||
|
||||
type Server struct {
|
||||
allowedClients []*AllowedClient
|
||||
listener net.Listener
|
||||
|
||||
listener net.Listener
|
||||
|
||||
httpClient *http.Client
|
||||
hostConn map[string]net.Conn
|
||||
hostConnMu sync.RWMutex
|
||||
|
||||
tcpPorts map[int]*AllowedClient
|
||||
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
|
|
@ -100,19 +105,19 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
id, err := peerID(conn.(*tls.Conn))
|
||||
if err != nil {
|
||||
s.log.Warning("Certificate error: %s", err)
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
|
||||
client, ok = s.checkID(id)
|
||||
if !ok {
|
||||
s.log.Warning("Unknown certificate: %q", id.String())
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
|
||||
req, err = http.NewRequest(http.MethodConnect, url(client, ""), nil)
|
||||
req, err = http.NewRequest(http.MethodConnect, url(client.Host), nil)
|
||||
if err != nil {
|
||||
s.log.Error("Invalid host %q for client %q", client.Host, client.ID)
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
|
||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||
|
|
@ -122,28 +127,37 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
|
||||
if err := s.addHostConn(client, conn); err != nil {
|
||||
s.log.Warning("Could not add host: %s", err)
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
|
||||
resp, err = s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.log.Warning("Handshake failed %s", err)
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
s.log.Warning("Handshake failed")
|
||||
goto cleanup
|
||||
goto reject
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
cleanup:
|
||||
reject:
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
s.deleteHostConn(client.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
|
||||
for _, c := range s.allowedClients {
|
||||
if id.Equals(c.ID) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) error {
|
||||
key := hostPort(client.Host)
|
||||
|
||||
|
|
@ -166,19 +180,9 @@ func (s *Server) deleteHostConn(host string) {
|
|||
}
|
||||
|
||||
func hostPort(host string) string {
|
||||
// TODO mma add support for custom ports
|
||||
return fmt.Sprint(host, ":443")
|
||||
}
|
||||
|
||||
func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
|
||||
for _, c := range s.allowedClients {
|
||||
if id.Equals(c.ID) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *Server) listenClientListeners() {
|
||||
for _, client := range s.allowedClients {
|
||||
if client.Listeners == nil {
|
||||
|
|
@ -198,49 +202,96 @@ func (s *Server) listen(l net.Listener, client *AllowedClient) {
|
|||
s.log.Warning("Accept failed: %s", err)
|
||||
continue
|
||||
}
|
||||
s.log.Debug("Accepted connection from %q", conn.RemoteAddr().String())
|
||||
s.log.Debug("Accepted connection from %q", conn.RemoteAddr())
|
||||
|
||||
// TODO mma get Protocol from Network
|
||||
// TODO mma get LocalIP from Addr
|
||||
msg := &ControlMessage{
|
||||
Action: RequestClientSession,
|
||||
Protocol: RAW,
|
||||
msg := &proto.ControlMessage{
|
||||
Action: proto.RequestClientSession,
|
||||
Protocol: l.Addr().Network(),
|
||||
ForwardedFor: conn.RemoteAddr().String(),
|
||||
ForwardedBy: conn.LocalAddr().String(),
|
||||
}
|
||||
|
||||
go s.proxy(conn, client, msg)
|
||||
go s.proxy(client.Host, conn, conn, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) proxy(conn net.Conn, client *AllowedClient, msg *ControlMessage) {
|
||||
defer conn.Close()
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
msg := &proto.ControlMessage{
|
||||
Action: proto.RequestClientSession,
|
||||
Protocol: proto.HTTPProtocol,
|
||||
ForwardedFor: r.RemoteAddr,
|
||||
ForwardedBy: r.Host,
|
||||
URLPath: r.URL.Path,
|
||||
}
|
||||
|
||||
s.proxy(trimPort(r.Host), w, r, msg)
|
||||
}
|
||||
|
||||
func trimPort(hostPort string) (host string) {
|
||||
host, _, _ = net.SplitHostPort(hostPort)
|
||||
if host == "" {
|
||||
return hostPort
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) proxy(host string, w io.Writer, r interface{}, msg *proto.ControlMessage) {
|
||||
s.log.Debug("Proxy init %s %v", host, msg)
|
||||
|
||||
defer func() {
|
||||
if c, ok := r.(io.Closer); ok {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
defer pw.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, url(client, ""), pr)
|
||||
req, err := http.NewRequest(http.MethodPut, url(host), pr)
|
||||
if err != nil {
|
||||
s.log.Error("Request creation failed: %s", err)
|
||||
return
|
||||
}
|
||||
msg.WriteTo(req.Header)
|
||||
|
||||
// read from caller, write to tunnel client
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
transfer("local to remote", pw, conn, s.log)
|
||||
wg.Done()
|
||||
}()
|
||||
var localToRemoteDone = make(chan struct{})
|
||||
|
||||
// read from tunnel client, write to caller
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.log.Error("Proxing conn from %q to %q failed: %s", conn.RemoteAddr().String(), client.Host, err)
|
||||
return
|
||||
localToRemote := func() {
|
||||
if hr, ok := r.(*http.Request); ok {
|
||||
hr.Write(pw)
|
||||
pw.Close()
|
||||
} else {
|
||||
transfer("local to remote", pw, r.(io.ReadCloser), s.log)
|
||||
}
|
||||
close(localToRemoteDone)
|
||||
}
|
||||
transfer("remote to local", conn, resp.Body, s.log)
|
||||
|
||||
wg.Wait()
|
||||
remoteToLocal := func() {
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.log.Error("Proxing conn to client %q failed: %s", host, err)
|
||||
return
|
||||
}
|
||||
if hw, ok := w.(http.ResponseWriter); ok {
|
||||
pr, err := http.ReadResponse(bufio.NewReader(resp.Body), r.(*http.Request))
|
||||
if err != nil {
|
||||
s.log.Error("Reading HTTP response failed: %s", err)
|
||||
return
|
||||
}
|
||||
copyHeader(hw.Header(), pr.Header)
|
||||
hw.WriteHeader(pr.StatusCode)
|
||||
transfer("remote to local", hw, pr.Body, s.log)
|
||||
} else {
|
||||
transfer("remote to local", w, resp.Body, s.log)
|
||||
}
|
||||
}
|
||||
|
||||
go localToRemote()
|
||||
remoteToLocal()
|
||||
<-localToRemoteDone
|
||||
|
||||
s.log.Debug("Proxy over %s %v", host, msg)
|
||||
}
|
||||
|
||||
func (s *Server) Addr() net.Addr {
|
||||
|
|
|
|||
17
utils.go
17
utils.go
|
|
@ -3,12 +3,13 @@ package h2tun
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
func url(client *AllowedClient, path string) string {
|
||||
return fmt.Sprint("https://", client.Host, path)
|
||||
func url(host string) string {
|
||||
return fmt.Sprint("https://", host)
|
||||
}
|
||||
|
||||
type closeWriter interface {
|
||||
|
|
@ -19,7 +20,7 @@ type closeReader interface {
|
|||
}
|
||||
|
||||
func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) {
|
||||
log.Debug("proxing")
|
||||
log.Debug("proxing %s", side)
|
||||
|
||||
n, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
|
|
@ -36,5 +37,13 @@ func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger)
|
|||
src.Close()
|
||||
}
|
||||
|
||||
log.Debug("done proxing %d bytes", n)
|
||||
log.Debug("done proxing %s %d bytes", side, n)
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, v := range src {
|
||||
vv := make([]string, len(v))
|
||||
copy(vv, v)
|
||||
dst[k] = vv
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue