mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
tcp proxy: custom listeners
This commit is contained in:
parent
db930471bc
commit
dcb2f8ba09
6 changed files with 255 additions and 127 deletions
24
client.go
24
client.go
|
|
@ -3,9 +3,11 @@ package h2tun
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/koding/logging"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
|
|
@ -14,6 +16,7 @@ type Client struct {
|
|||
tlsConfig *tls.Config
|
||||
conn net.Conn
|
||||
httpServer *http2.Server
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
|
||||
|
|
@ -21,6 +24,7 @@ func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
|
|||
serverAddr: serverAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
httpServer: &http2.Server{},
|
||||
log: logging.NewLogger("client"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -38,9 +42,25 @@ func (c *Client) Connect() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (fw flushWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = fw.w.Write(p)
|
||||
if f, ok := fw.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Client) proxy(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Print(r.URL)
|
||||
http.Error(w, "OK", http.StatusOK)
|
||||
c.log.Info("New proxy request")
|
||||
if r.Method == http.MethodConnect {
|
||||
http.Error(w, "OK", http.StatusOK)
|
||||
} else {
|
||||
io.Copy(flushWriter{w}, r.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
|
|
|
|||
26
control_msg.go
Normal file
26
control_msg.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package h2tun
|
||||
|
||||
// ControlMessage is sent from server to client to establish tunneled connection.
|
||||
type ControlMessage struct {
|
||||
Action Action
|
||||
Protocol Type
|
||||
LocalPort int
|
||||
}
|
||||
|
||||
// Action represents type of ControlMsg request.
|
||||
type Action int
|
||||
|
||||
// ControlMessage actions.
|
||||
const (
|
||||
RequestClientSession Action = iota + 1
|
||||
)
|
||||
|
||||
// Type represents tunneled connection type.
|
||||
type Type int
|
||||
|
||||
// ControlMessage protocols.
|
||||
const (
|
||||
HTTP Type = iota + 1
|
||||
WS
|
||||
RAW
|
||||
)
|
||||
|
|
@ -1,9 +1,14 @@
|
|||
package h2tun_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/andrew-d/id"
|
||||
"github.com/koding/h2tun"
|
||||
|
|
@ -13,21 +18,59 @@ import (
|
|||
|
||||
func TestHTTP(t *testing.T) {
|
||||
logging.DefaultLevel = logging.DEBUG
|
||||
logging.DefaultHandler.SetLevel(logging.DEBUG)
|
||||
|
||||
cert, err := loadTestCert()
|
||||
assert.Nil(t, err)
|
||||
clientID := idFromTLSCert(cert)
|
||||
|
||||
server, err := h2tun.NewServer(tlsConfig(cert), []*h2tun.AllowedClient{
|
||||
{ID: clientID, Host: "foobar.com"},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
listener, err := net.Listen("tcp", ":7777")
|
||||
assert.Nil(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
server, err := h2tun.NewServer(
|
||||
tlsConfig(cert),
|
||||
[]*h2tun.AllowedClient{
|
||||
{
|
||||
ID: clientID,
|
||||
Host: "foobar.com",
|
||||
Listeners: []net.Listener{listener},
|
||||
},
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := h2tun.NewClient(server.Addr().String(), tlsConfig(cert))
|
||||
go client.Connect()
|
||||
defer client.Close()
|
||||
|
||||
select {}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
conn, err := net.Dial("tcp", "localhost:7777")
|
||||
assert.Nil(t, err)
|
||||
|
||||
const testPayload = "this is a test"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
for _, c := range testPayload {
|
||||
_, err := conn.Write([]byte{byte(c)})
|
||||
assert.Nil(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
conn.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
io.Copy(b, conn)
|
||||
assert.Equal(t, testPayload, b.String())
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func loadTestCert() (tls.Certificate, error) {
|
||||
|
|
@ -50,7 +93,6 @@ func tlsConfig(cert tls.Certificate) *tls.Config {
|
|||
}
|
||||
|
||||
func idFromTLSCert(cert tls.Certificate) id.ID {
|
||||
// Get the x509 cert for the given TLS certificate.
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
|||
|
|
@ -1,85 +0,0 @@
|
|||
package h2tun_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func sleep() {
|
||||
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
type EchoMessage struct {
|
||||
Value string `json:"value,omitempty"`
|
||||
Close bool `json:"close,omitempty"`
|
||||
}
|
||||
|
||||
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w http.ResponseWriter, r *http.Request) (e error) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if e == nil {
|
||||
e = err
|
||||
}
|
||||
}()
|
||||
|
||||
if sleepFn != nil {
|
||||
sleepFn()
|
||||
}
|
||||
|
||||
for {
|
||||
var msg EchoMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadJSON error: %s", err)
|
||||
}
|
||||
|
||||
if sleepFn != nil {
|
||||
sleepFn()
|
||||
}
|
||||
|
||||
err = conn.WriteJSON(&msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("WriteJSON error: %s", err)
|
||||
}
|
||||
|
||||
if msg.Close {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, r.URL.Query().Get("echo"))
|
||||
}
|
||||
|
||||
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
sleep()
|
||||
handlerEchoHTTP(w, r)
|
||||
}
|
||||
|
||||
func handlerEchoTCP(conn net.Conn) {
|
||||
io.Copy(conn, conn)
|
||||
}
|
||||
|
||||
func handlerLatencyEchoTCP(conn net.Conn) {
|
||||
sleep()
|
||||
handlerEchoTCP(conn)
|
||||
}
|
||||
153
server.go
153
server.go
|
|
@ -3,17 +3,19 @@ package h2tun
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/andrew-d/id"
|
||||
"github.com/koding/logging"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// TODO mma add ListenerFunc func(net.Listener) net.Listener to allow for tls listener decoration
|
||||
// TODO mma add dynamic allowed clients modifications
|
||||
type Server struct {
|
||||
allowedClients []*AllowedClient
|
||||
listener net.Listener
|
||||
|
|
@ -22,12 +24,15 @@ type Server struct {
|
|||
hostConn map[string]net.Conn
|
||||
hostConnMu sync.RWMutex
|
||||
|
||||
tcpPorts map[int]*AllowedClient
|
||||
|
||||
log logging.Logger
|
||||
}
|
||||
|
||||
type AllowedClient struct {
|
||||
ID id.ID
|
||||
Host string
|
||||
ID id.ID
|
||||
Host string
|
||||
Listeners []net.Listener
|
||||
}
|
||||
|
||||
func NewServer(tlsConfig *tls.Config, allowedClients []*AllowedClient) (*Server, error) {
|
||||
|
|
@ -43,16 +48,13 @@ func NewServer(tlsConfig *tls.Config, allowedClients []*AllowedClient) (*Server,
|
|||
}
|
||||
s.initHTTPClient()
|
||||
|
||||
go s.listenControl()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) initHTTPClient() {
|
||||
// TODO mma try using connection pool for transport
|
||||
s.hostConn = make(map[string]net.Conn)
|
||||
|
||||
s.httpClient = &http.Client{
|
||||
// TODO mma try using connection pool for transport
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
s.hostConnMu.RLock()
|
||||
|
|
@ -68,6 +70,11 @@ func (s *Server) initHTTPClient() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
go s.listenControl()
|
||||
s.listenClientListeners()
|
||||
}
|
||||
|
||||
func (s *Server) listenControl() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
|
|
@ -82,70 +89,84 @@ func (s *Server) listenControl() {
|
|||
func (s *Server) handleClient(conn net.Conn) {
|
||||
s.log.Info("New client %s", conn.RemoteAddr().String())
|
||||
|
||||
var (
|
||||
client *AllowedClient
|
||||
req *http.Request
|
||||
resp *http.Response
|
||||
err error
|
||||
ok bool
|
||||
)
|
||||
|
||||
id, err := peerID(conn.(*tls.Conn))
|
||||
if err != nil {
|
||||
s.log.Warning("Certificate error: %s", err)
|
||||
conn.Close()
|
||||
return
|
||||
goto cleanup
|
||||
}
|
||||
client, ok := s.checkID(id)
|
||||
|
||||
client, ok = s.checkID(id)
|
||||
if !ok {
|
||||
s.log.Warning("Unknown certificate: %q", id.String())
|
||||
conn.Close()
|
||||
return
|
||||
goto cleanup
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodConnect, fmt.Sprintf("https://%s", client.Host), nil)
|
||||
req, err = http.NewRequest(http.MethodConnect, url(client, ""), nil)
|
||||
if err != nil {
|
||||
s.log.Error("Invalid host %q for client %q", client.Host, client.ID)
|
||||
conn.Close()
|
||||
return
|
||||
goto cleanup
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||
s.log.Warning("Setting no deadline failed: %s", err)
|
||||
// recoverable
|
||||
}
|
||||
|
||||
s.addHostConn(client, conn)
|
||||
if err := s.addHostConn(client, conn); err != nil {
|
||||
s.log.Warning("Could not add host: %s", err)
|
||||
goto cleanup
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
resp, err = s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.log.Warning("Handshake failed %s", err)
|
||||
conn.Close()
|
||||
s.deleteHostConn(client.Host)
|
||||
return
|
||||
goto cleanup
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
s.log.Warning("Handshake failed")
|
||||
conn.Close()
|
||||
goto cleanup
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
cleanup:
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
s.deleteHostConn(client.Host)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) {
|
||||
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) error {
|
||||
key := hostPort(client.Host)
|
||||
|
||||
s.hostConnMu.Lock()
|
||||
oldConn := s.hostConn[key]
|
||||
if oldConn != nil {
|
||||
s.log.Info("Closing old connection for host &q, old was from %s, new is from %s",
|
||||
client.Host, oldConn.RemoteAddr().String(), conn.RemoteAddr().String())
|
||||
oldConn.Close()
|
||||
defer s.hostConnMu.Unlock()
|
||||
|
||||
if c, ok := s.hostConn[key]; ok {
|
||||
return fmt.Errorf("client %q already connected from %q", client.ID, c.RemoteAddr().String())
|
||||
}
|
||||
|
||||
s.hostConn[key] = conn
|
||||
s.hostConnMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) deleteHostConn(host string) {
|
||||
key := hostPort(host)
|
||||
s.hostConnMu.Lock()
|
||||
delete(s.hostConn, key)
|
||||
delete(s.hostConn, hostPort(host))
|
||||
s.hostConnMu.Unlock()
|
||||
}
|
||||
|
||||
func hostPort(host string) string {
|
||||
// TODO add support for custom ports
|
||||
// TODO mma add support for custom ports
|
||||
return fmt.Sprint(host, ":443")
|
||||
}
|
||||
|
||||
|
|
@ -158,6 +179,70 @@ func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (s *Server) listenClientListeners() {
|
||||
for _, client := range s.allowedClients {
|
||||
if client.Listeners == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, l := range client.Listeners {
|
||||
go s.listen(l, client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listen(l net.Listener, client *AllowedClient) {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
s.log.Warning("Accept failed: %s", err)
|
||||
continue
|
||||
}
|
||||
s.log.Debug("Accepted connection from %q", conn.RemoteAddr().String())
|
||||
|
||||
// TODO mma get Protocol from Network
|
||||
// TODO mma get LocalIP from Addr
|
||||
msg := &ControlMessage{
|
||||
Action: RequestClientSession,
|
||||
Protocol: RAW,
|
||||
}
|
||||
|
||||
go s.proxy(conn, client, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) proxy(conn net.Conn, client *AllowedClient, msg *ControlMessage) {
|
||||
defer conn.Close()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
defer pw.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, url(client, ""), pr)
|
||||
if err != nil {
|
||||
s.log.Error("Request creation failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// read from caller, write to tunnel client
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
transfer("local to remote", pw, conn, s.log)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// read from tunnel client, write to caller
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.log.Error("Proxing conn from %q to %q failed: %s", conn.RemoteAddr().String(), client.Host, err)
|
||||
return
|
||||
}
|
||||
transfer("remote to local", conn, resp.Body, s.log)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) Addr() net.Addr {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
|
|
|||
40
utils.go
Normal file
40
utils.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package h2tun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/koding/logging"
|
||||
)
|
||||
|
||||
func url(client *AllowedClient, path string) string {
|
||||
return fmt.Sprint("https://", client.Host, path)
|
||||
}
|
||||
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
type closeReader interface {
|
||||
CloseRead() error
|
||||
}
|
||||
|
||||
func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) {
|
||||
log.Debug("proxing")
|
||||
|
||||
n, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
log.Error("%s: copy error: %s", side, err)
|
||||
}
|
||||
|
||||
if d, ok := dst.(closeWriter); ok {
|
||||
d.CloseWrite()
|
||||
}
|
||||
|
||||
if s, ok := src.(closeReader); ok {
|
||||
s.CloseRead()
|
||||
} else {
|
||||
src.Close()
|
||||
}
|
||||
|
||||
log.Debug("done proxing %d bytes", n)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue