tcp proxy: custom listeners

This commit is contained in:
mmatczuk 2016-09-19 23:42:14 +02:00
parent db930471bc
commit dcb2f8ba09
6 changed files with 255 additions and 127 deletions

View file

@ -3,9 +3,11 @@ package h2tun
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"github.com/koding/logging"
"golang.org/x/net/http2"
)
@ -14,6 +16,7 @@ type Client struct {
tlsConfig *tls.Config
conn net.Conn
httpServer *http2.Server
log logging.Logger
}
func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
@ -21,6 +24,7 @@ func NewClient(serverAddr string, tlsConfig *tls.Config) *Client {
serverAddr: serverAddr,
tlsConfig: tlsConfig,
httpServer: &http2.Server{},
log: logging.NewLogger("client"),
}
}
@ -38,9 +42,25 @@ func (c *Client) Connect() error {
return nil
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
n, err = fw.w.Write(p)
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}
func (c *Client) proxy(w http.ResponseWriter, r *http.Request) {
fmt.Print(r.URL)
http.Error(w, "OK", http.StatusOK)
c.log.Info("New proxy request")
if r.Method == http.MethodConnect {
http.Error(w, "OK", http.StatusOK)
} else {
io.Copy(flushWriter{w}, r.Body)
}
}
func (c *Client) Close() error {

26
control_msg.go Normal file
View file

@ -0,0 +1,26 @@
package h2tun
// ControlMessage is sent from server to client to establish tunneled connection.
type ControlMessage struct {
Action Action
Protocol Type
LocalPort int
}
// Action represents type of ControlMsg request.
type Action int
// ControlMessage actions.
const (
RequestClientSession Action = iota + 1
)
// Type represents tunneled connection type.
type Type int
// ControlMessage protocols.
const (
HTTP Type = iota + 1
WS
RAW
)

View file

@ -1,9 +1,14 @@
package h2tun_test
import (
"bytes"
"crypto/tls"
"crypto/x509"
"io"
"net"
"sync"
"testing"
"time"
"github.com/andrew-d/id"
"github.com/koding/h2tun"
@ -13,21 +18,59 @@ import (
func TestHTTP(t *testing.T) {
logging.DefaultLevel = logging.DEBUG
logging.DefaultHandler.SetLevel(logging.DEBUG)
cert, err := loadTestCert()
assert.Nil(t, err)
clientID := idFromTLSCert(cert)
server, err := h2tun.NewServer(tlsConfig(cert), []*h2tun.AllowedClient{
{ID: clientID, Host: "foobar.com"},
})
assert.NotNil(t, err)
listener, err := net.Listen("tcp", ":7777")
assert.Nil(t, err)
defer listener.Close()
server, err := h2tun.NewServer(
tlsConfig(cert),
[]*h2tun.AllowedClient{
{
ID: clientID,
Host: "foobar.com",
Listeners: []net.Listener{listener},
},
},
)
assert.Nil(t, err)
server.Start()
defer server.Close()
client := h2tun.NewClient(server.Addr().String(), tlsConfig(cert))
go client.Connect()
defer client.Close()
select {}
time.Sleep(time.Second)
conn, err := net.Dial("tcp", "localhost:7777")
assert.Nil(t, err)
const testPayload = "this is a test"
var wg sync.WaitGroup
wg.Add(2)
go func() {
for _, c := range testPayload {
_, err := conn.Write([]byte{byte(c)})
assert.Nil(t, err)
time.Sleep(time.Millisecond)
}
conn.Close()
wg.Done()
}()
go func() {
b := bytes.NewBuffer([]byte{})
io.Copy(b, conn)
assert.Equal(t, testPayload, b.String())
wg.Done()
}()
wg.Wait()
}
func loadTestCert() (tls.Certificate, error) {
@ -50,7 +93,6 @@ func tlsConfig(cert tls.Certificate) *tls.Config {
}
func idFromTLSCert(cert tls.Certificate) id.ID {
// Get the x509 cert for the given TLS certificate.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
panic(err)

View file

@ -1,85 +0,0 @@
package h2tun_test
import (
"fmt"
"io"
"math/rand"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
func sleep() {
time.Sleep(time.Duration(rand.Intn(2000)) * time.Millisecond)
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type EchoMessage struct {
Value string `json:"value,omitempty"`
Close bool `json:"close,omitempty"`
}
func handlerEchoWS(sleepFn func()) func(w http.ResponseWriter, r *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) (e error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return err
}
defer func() {
err := conn.Close()
if e == nil {
e = err
}
}()
if sleepFn != nil {
sleepFn()
}
for {
var msg EchoMessage
err := conn.ReadJSON(&msg)
if err != nil {
return fmt.Errorf("ReadJSON error: %s", err)
}
if sleepFn != nil {
sleepFn()
}
err = conn.WriteJSON(&msg)
if err != nil {
return fmt.Errorf("WriteJSON error: %s", err)
}
if msg.Close {
return nil
}
}
}
}
func handlerEchoHTTP(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.URL.Query().Get("echo"))
}
func handlerLatencyEchoHTTP(w http.ResponseWriter, r *http.Request) {
sleep()
handlerEchoHTTP(w, r)
}
func handlerEchoTCP(conn net.Conn) {
io.Copy(conn, conn)
}
func handlerLatencyEchoTCP(conn net.Conn) {
sleep()
handlerEchoTCP(conn)
}

153
server.go
View file

@ -3,17 +3,19 @@ package h2tun
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"github.com/andrew-d/id"
"github.com/koding/logging"
"golang.org/x/net/http2"
)
// TODO mma add ListenerFunc func(net.Listener) net.Listener to allow for tls listener decoration
// TODO mma add dynamic allowed clients modifications
type Server struct {
allowedClients []*AllowedClient
listener net.Listener
@ -22,12 +24,15 @@ type Server struct {
hostConn map[string]net.Conn
hostConnMu sync.RWMutex
tcpPorts map[int]*AllowedClient
log logging.Logger
}
type AllowedClient struct {
ID id.ID
Host string
ID id.ID
Host string
Listeners []net.Listener
}
func NewServer(tlsConfig *tls.Config, allowedClients []*AllowedClient) (*Server, error) {
@ -43,16 +48,13 @@ func NewServer(tlsConfig *tls.Config, allowedClients []*AllowedClient) (*Server,
}
s.initHTTPClient()
go s.listenControl()
return s, nil
}
func (s *Server) initHTTPClient() {
// TODO mma try using connection pool for transport
s.hostConn = make(map[string]net.Conn)
s.httpClient = &http.Client{
// TODO mma try using connection pool for transport
Transport: &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
s.hostConnMu.RLock()
@ -68,6 +70,11 @@ func (s *Server) initHTTPClient() {
}
}
func (s *Server) Start() {
go s.listenControl()
s.listenClientListeners()
}
func (s *Server) listenControl() {
for {
conn, err := s.listener.Accept()
@ -82,70 +89,84 @@ func (s *Server) listenControl() {
func (s *Server) handleClient(conn net.Conn) {
s.log.Info("New client %s", conn.RemoteAddr().String())
var (
client *AllowedClient
req *http.Request
resp *http.Response
err error
ok bool
)
id, err := peerID(conn.(*tls.Conn))
if err != nil {
s.log.Warning("Certificate error: %s", err)
conn.Close()
return
goto cleanup
}
client, ok := s.checkID(id)
client, ok = s.checkID(id)
if !ok {
s.log.Warning("Unknown certificate: %q", id.String())
conn.Close()
return
goto cleanup
}
req, err := http.NewRequest(http.MethodConnect, fmt.Sprintf("https://%s", client.Host), nil)
req, err = http.NewRequest(http.MethodConnect, url(client, ""), nil)
if err != nil {
s.log.Error("Invalid host %q for client %q", client.Host, client.ID)
conn.Close()
return
goto cleanup
}
if err := conn.SetDeadline(time.Time{}); err != nil {
if err = conn.SetDeadline(time.Time{}); err != nil {
s.log.Warning("Setting no deadline failed: %s", err)
// recoverable
}
s.addHostConn(client, conn)
if err := s.addHostConn(client, conn); err != nil {
s.log.Warning("Could not add host: %s", err)
goto cleanup
}
resp, err := s.httpClient.Do(req)
resp, err = s.httpClient.Do(req)
if err != nil {
s.log.Warning("Handshake failed %s", err)
conn.Close()
s.deleteHostConn(client.Host)
return
goto cleanup
}
if resp.StatusCode != http.StatusOK {
s.log.Warning("Handshake failed")
conn.Close()
goto cleanup
}
return
cleanup:
conn.Close()
if client != nil {
s.deleteHostConn(client.Host)
return
}
}
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) {
func (s *Server) addHostConn(client *AllowedClient, conn net.Conn) error {
key := hostPort(client.Host)
s.hostConnMu.Lock()
oldConn := s.hostConn[key]
if oldConn != nil {
s.log.Info("Closing old connection for host &q, old was from %s, new is from %s",
client.Host, oldConn.RemoteAddr().String(), conn.RemoteAddr().String())
oldConn.Close()
defer s.hostConnMu.Unlock()
if c, ok := s.hostConn[key]; ok {
return fmt.Errorf("client %q already connected from %q", client.ID, c.RemoteAddr().String())
}
s.hostConn[key] = conn
s.hostConnMu.Unlock()
return nil
}
func (s *Server) deleteHostConn(host string) {
key := hostPort(host)
s.hostConnMu.Lock()
delete(s.hostConn, key)
delete(s.hostConn, hostPort(host))
s.hostConnMu.Unlock()
}
func hostPort(host string) string {
// TODO add support for custom ports
// TODO mma add support for custom ports
return fmt.Sprint(host, ":443")
}
@ -158,6 +179,70 @@ func (s *Server) checkID(id id.ID) (*AllowedClient, bool) {
return nil, false
}
func (s *Server) listenClientListeners() {
for _, client := range s.allowedClients {
if client.Listeners == nil {
continue
}
for _, l := range client.Listeners {
go s.listen(l, client)
}
}
}
func (s *Server) listen(l net.Listener, client *AllowedClient) {
for {
conn, err := l.Accept()
if err != nil {
s.log.Warning("Accept failed: %s", err)
continue
}
s.log.Debug("Accepted connection from %q", conn.RemoteAddr().String())
// TODO mma get Protocol from Network
// TODO mma get LocalIP from Addr
msg := &ControlMessage{
Action: RequestClientSession,
Protocol: RAW,
}
go s.proxy(conn, client, msg)
}
}
func (s *Server) proxy(conn net.Conn, client *AllowedClient, msg *ControlMessage) {
defer conn.Close()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
req, err := http.NewRequest(http.MethodPut, url(client, ""), pr)
if err != nil {
s.log.Error("Request creation failed: %s", err)
return
}
// read from caller, write to tunnel client
var wg sync.WaitGroup
wg.Add(1)
go func() {
transfer("local to remote", pw, conn, s.log)
wg.Done()
}()
// read from tunnel client, write to caller
resp, err := s.httpClient.Do(req)
if err != nil {
s.log.Error("Proxing conn from %q to %q failed: %s", conn.RemoteAddr().String(), client.Host, err)
return
}
transfer("remote to local", conn, resp.Body, s.log)
wg.Wait()
}
func (s *Server) Addr() net.Addr {
return s.listener.Addr()
}

40
utils.go Normal file
View file

@ -0,0 +1,40 @@
package h2tun
import (
"fmt"
"io"
"github.com/koding/logging"
)
func url(client *AllowedClient, path string) string {
return fmt.Sprint("https://", client.Host, path)
}
type closeWriter interface {
CloseWrite() error
}
type closeReader interface {
CloseRead() error
}
func transfer(side string, dst io.Writer, src io.ReadCloser, log logging.Logger) {
log.Debug("proxing")
n, err := io.Copy(dst, src)
if err != nil {
log.Error("%s: copy error: %s", side, err)
}
if d, ok := dst.(closeWriter); ok {
d.CloseWrite()
}
if s, ok := src.(closeReader); ok {
s.CloseRead()
} else {
src.Close()
}
log.Debug("done proxing %d bytes", n)
}