Merge pull request #11 from mmatczuk/mmt/cli

command line tunnel
This commit is contained in:
Michał Matczuk 2017-02-15 09:46:22 +01:00 committed by GitHub
commit efe21afce5
20 changed files with 1233 additions and 521 deletions

10
TODO.md
View file

@ -1,18 +1,18 @@
Release 1.0
1. cli: cli and file configuration based on ngrok2 https://ngrok.com/docs#config
1. docs: README update
1. docs: new README
Backlog
1. monitoring: client connection state machine
1. monitoring: ping https://godoc.org/github.com/hashicorp/yamux#Session.Ping
1. monitoring: prometheus.io integration
1. proxy: WebSockets
1. docs: demo
1. proxy: UDP
1. proxy: file system
1. proxy: host_header modifier
1. security: certificate signature checks
1. cli: integrated certificate generation
1. monitoring: prometheus.io integration
Notes for README

View file

@ -16,9 +16,8 @@ import (
)
var (
// DefaultDialTimeout specifies how long client should wait for tunnel
// server or local service connection.
DefaultDialTimeout = 10 * time.Second
// DefaultTimeout specifies general purpose timeout.
DefaultTimeout = 10 * time.Second
)
// ClientConfig is configuration of the Client.
@ -51,6 +50,7 @@ type Client struct {
conn net.Conn
connMu sync.Mutex
httpServer *http2.Server
serverErr error
logger log.Logger
}
@ -63,6 +63,9 @@ func NewClient(config *ClientConfig) *Client {
if config.TLSClientConfig == nil {
panic("Missing TLSClientConfig")
}
if config.Tunnels == nil || len(config.Tunnels) == 0 {
panic("Missing Tunnels")
}
if config.Proxy == nil {
panic("Missing Proxy")
}
@ -85,29 +88,53 @@ func NewClient(config *ClientConfig) *Client {
// error, otherwise it spawns a new goroutine with http/2 server handling
// ControlMessages.
func (c *Client) Start() error {
c.connMu.Lock()
defer c.connMu.Unlock()
c.logger.Log(
"level", 1,
"action", "start",
)
for {
conn, err := c.connect()
if err != nil {
return err
}
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
Handler: http.HandlerFunc(c.serveHTTP),
})
c.logger.Log(
"level", 1,
"action", "disconnected",
)
c.connMu.Lock()
err = c.serverErr
c.conn = nil
c.serverErr = nil
c.connMu.Unlock()
if err != nil {
return fmt.Errorf("server error: %s", err)
}
}
}
func (c *Client) connect() (net.Conn, error) {
c.connMu.Lock()
defer c.connMu.Unlock()
if c.conn != nil {
return fmt.Errorf("already connected")
return nil, fmt.Errorf("already connected")
}
conn, err := c.dial()
if err != nil {
return fmt.Errorf("failed to connect to server: %s", err)
return nil, fmt.Errorf("failed to connect to server: %s", err)
}
c.conn = conn
go c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
Handler: http.HandlerFunc(c.serveHTTP),
})
return nil
return conn, nil
}
func (c *Client) dial() (net.Conn, error) {
@ -129,7 +156,7 @@ func (c *Client) dial() (net.Conn, error) {
conn, err = c.config.DialTLS(network, addr, tlsConfig)
} else {
conn, err = tls.DialWithDialer(
&net.Dialer{Timeout: DefaultDialTimeout},
&net.Dialer{Timeout: DefaultTimeout},
network, addr, tlsConfig,
)
}
@ -181,7 +208,11 @@ func (c *Client) dial() (net.Conn, error) {
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
c.handleHandshake(w, r)
if r.Header.Get(proto.ErrorHeader) != "" {
c.handleHandshakeError(w, r)
} else {
c.handleHandshake(w, r)
}
return
}
@ -218,6 +249,21 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
)
}
func (c *Client) handleHandshakeError(w http.ResponseWriter, r *http.Request) {
err := fmt.Errorf(r.Header.Get(proto.ErrorHeader))
c.logger.Log(
"level", 1,
"action", "handshake error",
"addr", r.RemoteAddr,
"err", err,
)
c.connMu.Lock()
c.serverErr = err
c.connMu.Unlock()
}
func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
c.logger.Log(
"level", 1,
@ -227,18 +273,16 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if c.config.Tunnels != nil {
b, err := json.Marshal(c.config.Tunnels)
if err != nil {
c.logger.Log(
"level", 0,
"msg", "handshake failed",
"err", err,
)
return
}
w.Write(b)
b, err := json.Marshal(c.config.Tunnels)
if err != nil {
c.logger.Log(
"level", 0,
"msg", "handshake failed",
"err", err,
)
return
}
w.Write(b)
}
// Stop closes the connection between client and server. After stopping client

View file

@ -10,6 +10,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/mmatczuk/tunnel/mock"
"github.com/mmatczuk/tunnel/proto"
)
func TestClient_Dial(t *testing.T) {
@ -23,7 +24,8 @@ func TestClient_Dial(t *testing.T) {
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
Proxy: Proxy(ProxyFuncs{}),
Tunnels: map[string]*proto.Tunnel{"test": {}},
Proxy: Proxy(ProxyFuncs{}),
})
conn, err := c.dial()
@ -57,6 +59,7 @@ func TestClient_DialBackoff(t *testing.T) {
TLSClientConfig: &tls.Config{},
DialTLS: d,
Backoff: b,
Tunnels: map[string]*proto.Tunnel{"test": {}},
Proxy: Proxy(ProxyFuncs{}),
})

37
cmd/cmd/log.go Normal file
View file

@ -0,0 +1,37 @@
package cmd
import (
"io"
"os"
"time"
kitlog "github.com/go-kit/kit/log"
"github.com/mmatczuk/tunnel/log"
)
// NewLogger returns logfmt based logger, printing messages up to log level
// logLevel.
func NewLogger(to string, level int) (log.Logger, error) {
var w io.Writer
switch to {
case "none":
return log.NewNopLogger(), nil
case "stdout":
w = os.Stdout
case "stderr":
w = os.Stderr
default:
f, err := os.Create(to)
if err != nil {
return nil, err
}
w = f
}
var logger kitlog.Logger
logger = kitlog.NewJSONLogger(kitlog.NewSyncWriter(w))
logger = kitlog.NewContext(logger).WithPrefix("time", kitlog.Timestamp(time.Now))
logger = log.NewFilterLogger(logger, level)
return logger, nil
}

152
cmd/tunnel/config.go Normal file
View file

@ -0,0 +1,152 @@
package main
import (
"fmt"
"io/ioutil"
"path/filepath"
"time"
"gopkg.in/yaml.v2"
"github.com/mmatczuk/tunnel/proto"
)
type BackoffConfig struct {
InitialInterval time.Duration `yaml:"interval,omitempty"`
Multiplier float64 `yaml:"multiplier,omitempty"`
MaxInterval time.Duration `yaml:"max_interval,omitempty"`
MaxElapsedTime time.Duration `yaml:"max_time,omitempty"`
}
type TunnelConfig struct {
Protocol string `yaml:"proto,omitempty"`
Addr string `yaml:"addr,omitempty"`
Auth string `yaml:"auth,omitempty"`
Host string `yaml:"host,omitempty"`
RemoteAddr string `yaml:"remote_addr,omitempty"`
}
type Config struct {
ServerAddr string `yaml:"server_addr,omitempty"`
InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"`
TLSCrt string `yaml:"tls_crt,omitempty"`
TLSKey string `yaml:"tls_key,omitempty"`
Backoff *BackoffConfig `yaml:"backoff,omitempty"`
Tunnels map[string]*TunnelConfig `yaml:"tunnels,omitempty"`
}
var defaultBackoffConfig = BackoffConfig{
InitialInterval: 500 * time.Millisecond,
Multiplier: 1.5,
MaxInterval: 60 * time.Second,
MaxElapsedTime: 15 * time.Minute,
}
func loadConfiguration(path string) (*Config, error) {
configBuf, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read file %q: %s", path, err)
}
// deserialize/parse the config
var config Config
if err = yaml.Unmarshal(configBuf, &config); err != nil {
return nil, fmt.Errorf("failed to parse file %q: %s", path, err)
}
// set default values
if config.TLSCrt == "" {
config.TLSCrt = filepath.Join(filepath.Dir(path), "client.crt")
}
if config.TLSKey == "" {
config.TLSKey = filepath.Join(filepath.Dir(path), "client.key")
}
if config.Backoff == nil {
config.Backoff = &defaultBackoffConfig
} else {
if config.Backoff.InitialInterval == 0 {
config.Backoff.InitialInterval = defaultBackoffConfig.InitialInterval
}
if config.Backoff.Multiplier == 0 {
config.Backoff.Multiplier = defaultBackoffConfig.Multiplier
}
if config.Backoff.MaxInterval == 0 {
config.Backoff.MaxInterval = defaultBackoffConfig.MaxInterval
}
if config.Backoff.MaxElapsedTime == 0 {
config.Backoff.MaxElapsedTime = defaultBackoffConfig.MaxElapsedTime
}
}
// validate and normalize configuration
if config.ServerAddr == "" {
return nil, fmt.Errorf("server_addr: missing")
}
if config.ServerAddr, err = normalizeAddress(config.ServerAddr); err != nil {
return nil, fmt.Errorf("server_addr: %s", err)
}
for name, t := range config.Tunnels {
switch t.Protocol {
case proto.HTTP:
if err := validateHTTP(t); err != nil {
return nil, fmt.Errorf("%s %s", name, err)
}
case proto.TCP, proto.TCP4, proto.TCP6:
if err := validateTCP(t); err != nil {
return nil, fmt.Errorf("%s %s", name, err)
}
default:
return nil, fmt.Errorf("%s invalid protocol %q", name, t.Protocol)
}
}
return &config, nil
}
func validateHTTP(t *TunnelConfig) error {
var err error
if t.Host == "" {
return fmt.Errorf("host: missing")
}
if t.Addr == "" {
return fmt.Errorf("addr: missing")
}
if t.Addr, err = normalizeURL(t.Addr); err != nil {
return fmt.Errorf("addr: %s", err)
}
// unexpected
if t.RemoteAddr != "" {
return fmt.Errorf("remote_addr: unexpected")
}
return nil
}
func validateTCP(t *TunnelConfig) error {
var err error
if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil {
return fmt.Errorf("remote_addr: %s", err)
}
if t.Addr == "" {
return fmt.Errorf("addr: missing")
}
if t.Addr, err = normalizeAddress(t.Addr); err != nil {
return fmt.Errorf("addr: %s", err)
}
// unexpected
if t.Host != "" {
return fmt.Errorf("host: unexpected")
}
if t.Auth != "" {
return fmt.Errorf("auth: unexpected")
}
return nil
}

52
cmd/tunnel/normalize.go Normal file
View file

@ -0,0 +1,52 @@
package main
import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
)
func normalizeAddress(addr string) (string, error) {
// normalize port to addr
if _, err := strconv.Atoi(addr); err == nil {
addr = ":" + addr
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
if host == "" {
host = "127.0.0.1"
}
return fmt.Sprintf("%s:%s", host, port), nil
}
func normalizeURL(rawurl string) (string, error) {
// check scheme
s := strings.SplitN(rawurl, "://", 2)
if len(s) > 1 {
switch s[0] {
case "http", "https":
default:
return "", fmt.Errorf("unsupported url schema, choose 'http' or 'https'")
}
} else {
rawurl = fmt.Sprint("http://", rawurl)
}
u, err := url.Parse(rawurl)
if err != nil {
return "", err
}
if u.Path != "" && !strings.HasSuffix(u.Path, "/") {
return "", fmt.Errorf("url must end with '/'")
}
return rawurl, nil
}

View file

@ -0,0 +1,111 @@
package main
import (
"strings"
"testing"
)
func TestNormalizeAddress(t *testing.T) {
t.Parallel()
tests := []struct {
addr string
expected string
error string
}{
{
addr: "22",
expected: "127.0.0.1:22",
},
{
addr: ":22",
expected: "127.0.0.1:22",
},
{
addr: "0.0.0.0:22",
expected: "0.0.0.0:22",
},
{
addr: "0.0.0.0",
error: "missing port",
},
{
addr: "",
error: "missing port",
},
}
for i, tt := range tests {
actual, err := normalizeAddress(tt.addr)
if actual != tt.expected {
t.Errorf("[%d] expected %q got %q err: %s", i, tt.expected, actual, err)
}
if tt.error != "" && err == nil {
t.Errorf("[%d] expected error", i)
}
if err != nil && (tt.error == "" || !strings.Contains(err.Error(), tt.error)) {
t.Errorf("[%d] expected error contains %q, got %q", i, tt.error, err)
}
}
}
func TestNormalizeURL(t *testing.T) {
t.Parallel()
tests := []struct {
rawurl string
expected string
error string
}{
{
rawurl: "localhost",
expected: "http://localhost",
},
{
rawurl: "localhost:80",
expected: "http://localhost:80",
},
{
rawurl: "localhost:80/path/",
expected: "http://localhost:80/path/",
},
{
rawurl: "localhost:80/path",
error: "/",
},
{
rawurl: "https://localhost",
expected: "https://localhost",
},
{
rawurl: "https://localhost:443",
expected: "https://localhost:443",
},
{
rawurl: "https://localhost:443/path/",
expected: "https://localhost:443/path/",
},
{
rawurl: "https://localhost:443/path",
error: "/",
},
{
rawurl: "ftp://localhost",
error: "unsupported url schema",
},
}
for i, tt := range tests {
actual, err := normalizeURL(tt.rawurl)
if actual != tt.expected {
t.Errorf("[%d] expected %q got %q, err: %s", i, tt.expected, actual, err)
}
if tt.error != "" && err == nil {
t.Errorf("[%d] expected error", i)
}
if err != nil && (tt.error == "" || !strings.Contains(err.Error(), tt.error)) {
t.Errorf("[%d] expected error contains %q, got %q", i, tt.error, err)
}
}
}

108
cmd/tunnel/options.go Normal file
View file

@ -0,0 +1,108 @@
package main
import (
"flag"
"fmt"
"os"
"os/user"
"path/filepath"
)
const usage1 string = `Usage: tunnel [OPTIONS] <command> [command args] [...]
options:
`
const usage2 string = `
Commands:
tunnel start [tunnel] [...] Start tunnels by name from config file
tunnel start-all Start all tunnels defined in config file
tunnel list List tunnel names from config file
Examples:
tunnel start www ssh
tunnel -config=config.yaml -log=stdout -log-level 2 start ssh
tunnel start-all
config.yaml:
server_addr: SERVER_IP:4443
insecure_skip_verify: true
tunnels:
www:
proto: http
addr: http://IP:8080/ui/
auth: user:password
host: ui.mytunnel.com
ssh:
proto: tcp
addr: IP:22
remote_addr: 0.0.0.0:2222
`
func init() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage1)
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, usage2)
}
}
type options struct {
debug bool
config string
logTo string
logLevel int
command string
args []string
}
func parseArgs() (*options, error) {
debug := flag.Bool("debug", false, "Starts gops agent")
config := flag.String("config", filepath.Join(defaultPath(), "config.yaml"), "Path to tunnel configuration file")
logTo := flag.String("log", "stdout", "Write log messages to this file, file name or 'stdout', 'stderr', 'none'")
logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3")
flag.Parse()
opts := &options{
debug: *debug,
config: *config,
logTo: *logTo,
logLevel: *logLevel,
command: flag.Arg(0),
}
switch opts.command {
case "list":
opts.args = flag.Args()[1:]
if len(opts.args) > 0 {
return nil, fmt.Errorf("list takes no arguments")
}
case "start":
opts.args = flag.Args()[1:]
if len(opts.args) == 0 {
return nil, fmt.Errorf("you must specify at least one tunnel to start")
}
case "start-all":
opts.args = flag.Args()[1:]
if len(opts.args) > 0 {
return nil, fmt.Errorf("start-all takes no arguments")
}
default:
return nil, fmt.Errorf("expected command")
}
return opts, nil
}
func defaultPath() string {
// user.Current() does not work on linux when cross compiling because
// it requires CGO; use os.Getenv("HOME") hack until we compile natively
var dir string
if user, err := user.Current(); err == nil {
dir = user.HomeDir
} else {
dir = os.Getenv("HOME")
}
return filepath.Join(dir, ".tunnel")
}

153
cmd/tunnel/tunnel.go Normal file
View file

@ -0,0 +1,153 @@
package main
import (
"crypto/tls"
"fmt"
"net/url"
"os"
"sort"
"gopkg.in/yaml.v2"
"github.com/cenkalti/backoff"
"github.com/google/gops/agent"
"github.com/mmatczuk/tunnel"
"github.com/mmatczuk/tunnel/cmd/cmd"
"github.com/mmatczuk/tunnel/log"
"github.com/mmatczuk/tunnel/proto"
)
func main() {
opts, err := parseArgs()
if err != nil {
fatal(err.Error())
}
if opts.debug {
if err := agent.Listen(nil); err != nil {
fatal("gops agent failed to start: %s", err)
}
}
logger, err := cmd.NewLogger(opts.logTo, opts.logLevel)
if err != nil {
fatal("failed to init logger: %s", err)
}
// read configuration file
config, err := loadConfiguration(opts.config)
if err != nil {
fatal("configuration error: %s", err)
}
switch opts.command {
case "list":
var names []string
for n := range config.Tunnels {
names = append(names, n)
}
sort.Strings(names)
for _, n := range names {
fmt.Println(n)
}
return
case "start":
tunnels := make(map[string]*TunnelConfig)
for _, arg := range opts.args {
t, ok := config.Tunnels[arg]
if !ok {
fatal("no such tunnel %q", arg)
}
tunnels[arg] = t
}
config.Tunnels = tunnels
}
cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey)
if err != nil {
fatal("failed to load certificate: %s", err)
}
b, err := yaml.Marshal(config)
if err != nil {
panic(err)
}
logger.Log("config", string(b))
client := tunnel.NewClient(&tunnel.ClientConfig{
ServerAddr: config.ServerAddr,
TLSClientConfig: tlsConfig(cert, config),
Backoff: expBackoff(config.Backoff),
Tunnels: tunnels(config.Tunnels),
Proxy: proxy(config.Tunnels, logger),
Logger: logger,
})
if err := client.Start(); err != nil {
fatal("%s", err)
}
}
func tlsConfig(cert tls.Certificate, config *Config) *tls.Config {
return &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: config.InsecureSkipVerify,
}
}
func expBackoff(config *BackoffConfig) *backoff.ExponentialBackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = config.InitialInterval
b.Multiplier = config.Multiplier
b.MaxInterval = config.MaxInterval
b.MaxElapsedTime = config.MaxElapsedTime
return b
}
func tunnels(m map[string]*TunnelConfig) map[string]*proto.Tunnel {
p := make(map[string]*proto.Tunnel)
for name, t := range m {
p[name] = &proto.Tunnel{
Protocol: t.Protocol,
Host: t.Host,
Auth: t.Auth,
Addr: t.RemoteAddr,
}
}
return p
}
func proxy(m map[string]*TunnelConfig, logger log.Logger) tunnel.ProxyFunc {
httpURL := make(map[string]*url.URL)
tcpAddr := make(map[string]string)
for _, t := range m {
switch t.Protocol {
case proto.HTTP:
u, err := url.Parse(t.Addr)
if err != nil {
panic(err)
}
httpURL[t.Host] = u
case proto.TCP, proto.TCP4, proto.TCP6:
tcpAddr[t.RemoteAddr] = t.Addr
}
}
return tunnel.Proxy(tunnel.ProxyFuncs{
HTTP: tunnel.NewMultiHTTPProxy(httpURL, logger).Proxy,
TCP: tunnel.NewMultiTCPProxy(tcpAddr, logger).Proxy,
})
}
func fatal(format string, a ...interface{}) {
fmt.Fprintf(os.Stderr, format, a...)
fmt.Fprint(os.Stderr, "\n")
os.Exit(1)
}

56
cmd/tunneld/options.go Normal file
View file

@ -0,0 +1,56 @@
package main
import (
"flag"
"fmt"
"os"
)
const usage1 string = `Usage: tunneld [OPTIONS]
options:
`
func init() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage1)
flag.PrintDefaults()
}
}
// options specify arguments read command line arguments.
type options struct {
debug bool
httpAddr string
httpsAddr string
tunnelAddr string
tlsCrt string
tlsKey string
clients string
logTo string
logLevel int
}
func parseArgs() *options {
debug := flag.Bool("debug", false, "Starts gops agent")
httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable")
httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable")
tunnelAddr := flag.String("tunnelAddr", ":4443", "Public address listening for tunnel client")
tlsCrt := flag.String("tlsCrt", "", "Path to a TLS certificate file")
tlsKey := flag.String("tlsKey", "", "Path to a TLS key file")
clients := flag.String("clients", "", "Comma-separated list of tunnel client ids")
logTo := flag.String("log", "stdout", "Write log messages to this file, file name or 'stdout', 'stderr', 'none'")
logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3")
flag.Parse()
return &options{
debug: *debug,
httpAddr: *httpAddr,
httpsAddr: *httpsAddr,
tunnelAddr: *tunnelAddr,
tlsCrt: *tlsCrt,
tlsKey: *tlsKey,
clients: *clients,
logTo: *logTo,
logLevel: *logLevel,
}
}

114
cmd/tunneld/tunneld.go Normal file
View file

@ -0,0 +1,114 @@
package main
import (
"crypto/tls"
"fmt"
"net/http"
"os"
"strings"
"github.com/google/gops/agent"
"github.com/mmatczuk/tunnel"
"github.com/mmatczuk/tunnel/cmd/cmd"
"github.com/mmatczuk/tunnel/id"
)
func main() {
opts := parseArgs()
if opts.debug {
if err := agent.Listen(nil); err != nil {
fatal("gops agent failed to start: %s", err)
}
}
logger, err := cmd.NewLogger(opts.logTo, opts.logLevel)
if err != nil {
fatal("failed to init logger: %s", err)
}
// load certs
cert, err := tls.LoadX509KeyPair(opts.tlsCrt, opts.tlsKey)
if err != nil {
fatal("failed to load certificate: %s", err)
}
// setup server
server, err := tunnel.NewServer(&tunnel.ServerConfig{
Addr: opts.tunnelAddr,
TLSConfig: tlsConfig(cert),
Logger: logger,
})
if err != nil {
fatal("failed to create server: %s", err)
}
if opts.clients == "" {
logger.Log(
"level", 0,
"msg", "No clients",
)
} else {
for _, c := range strings.Split(opts.clients, ",") {
if c == "" {
fatal("empty client id")
}
identifier := id.ID{}
err := identifier.UnmarshalText([]byte(c))
if err != nil {
fatal("invalid identifier %q: %s", c, err)
}
server.Subscribe(identifier)
}
}
// start HTTP
if opts.httpAddr != "" {
go func() {
logger.Log(
"level", 1,
"action", "start http",
"addr", opts.httpAddr,
)
err := http.ListenAndServe(opts.httpAddr, server)
if err != nil {
fatal("failed to start HTTP: %s", err)
}
}()
}
// start HTTPS
if opts.httpsAddr != "" {
go func() {
logger.Log(
"level", 1,
"action", "start https",
"addr", opts.httpsAddr,
)
err := http.ListenAndServeTLS(opts.httpsAddr, opts.tlsCrt, opts.tlsKey, server)
if err != nil {
fatal("failed to start HTTPS: %s", err)
}
}()
}
server.Start()
}
func tlsConfig(cert tls.Certificate) *tls.Config {
return &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequestClientCert,
SessionTicketsDisabled: true,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
PreferServerCipherSuites: true,
NextProtos: []string{"h2"},
}
}
func fatal(format string, a ...interface{}) {
fmt.Fprintf(os.Stderr, format, a...)
fmt.Fprint(os.Stderr, "\n")
os.Exit(1)
}

View file

@ -9,7 +9,7 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"strings"
"path"
"github.com/mmatczuk/tunnel/log"
"github.com/mmatczuk/tunnel/proto"
@ -101,6 +101,8 @@ func (p *HTTPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessag
// is correctly routed based on localURL and localURLMap. If no URL can be found
// the request is canceled.
func (p *HTTPProxy) Director(req *http.Request) {
orig := *req.URL
target := p.localURLFor(req.URL)
if target == nil {
p.logger.Log(
@ -129,18 +131,26 @@ func (p *HTTPProxy) Director(req *http.Request) {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
req.Host = req.URL.Host
p.logger.Log(
"level", 2,
"action", "url rewrite",
"from", &orig,
"to", req.URL,
)
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
if a == "" || a == "/" {
return b
}
return a + b
if b == "" || b == "/" {
return a
}
return path.Join(a, b)
}
func (p *HTTPProxy) localURLFor(u *url.URL) *url.URL {

View file

@ -15,6 +15,7 @@ import (
"github.com/mmatczuk/tunnel"
"github.com/mmatczuk/tunnel/id"
"github.com/mmatczuk/tunnel/log"
"github.com/mmatczuk/tunnel/proto"
)
const (
@ -37,13 +38,6 @@ var ctx testContext
func TestMain(m *testing.M) {
logger := log.NewFilterLogger(log.NewStdLogger(), 1)
// prepare server TCP listener
serverTCPListener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer serverTCPListener.Close()
// prepare tunnel server
cert, identifier := selfSignedCert()
s, err := tunnel.NewServer(&tunnel.ServerConfig{
@ -54,30 +48,19 @@ func TestMain(m *testing.M) {
if err != nil {
panic(err)
}
s.Subscribe(identifier)
auth := &tunnel.Auth{
User: "user",
Password: "password",
}
if err := s.AddHost("localhost", auth, identifier); err != nil {
panic(err)
}
if err := s.AddListener(serverTCPListener, identifier); err != nil {
panic(err)
}
s.Start()
go s.Start()
defer s.Stop()
// run server HTTP interface
serverHTTPListener, err := net.Listen("tcp", ":0")
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer serverHTTPListener.Close()
go http.Serve(serverHTTPListener, s)
defer l.Close()
go http.Serve(l, s)
httpAddr := l.Addr()
// prepare local TCP echo service
echoTCPListener, err := net.Listen("tcp", ":0")
@ -95,37 +78,50 @@ func TestMain(m *testing.M) {
defer echoHTTPListener.Close()
go EchoHTTP(echoHTTPListener)
// prepare proxy
httpproxy := tunnel.NewMultiHTTPProxy(map[string]*url.URL{
"localhost:" + Port(serverHTTPListener.Addr()): {
Scheme: "http",
Host: echoHTTPListener.Addr().String(),
},
}, log.NewNopLogger())
tcpproxy := tunnel.NewMultiTCPProxy(map[string]string{
Port(serverTCPListener.Addr()): echoTCPListener.Addr().String(),
}, log.NewNopLogger())
proxy := tunnel.Proxy(tunnel.ProxyFuncs{
HTTP: httpproxy.Proxy,
TCP: tcpproxy.Proxy,
})
// allocate free port
tcpAddr := freeAddr()
// prepare tunnel client
tunnels := map[string]*proto.Tunnel{
"http": {
Protocol: proto.HTTP,
Host: "localhost",
Auth: "user:password",
},
"tcp": {
Protocol: proto.TCP,
Addr: tcpAddr.String(),
},
}
httpProxy := tunnel.NewMultiHTTPProxy(map[string]*url.URL{
"localhost:" + port(httpAddr): {
Scheme: "http",
Host: "127.0.0.1:" + port(echoHTTPListener.Addr()),
},
}, log.NewContext(logger).WithPrefix("HTTP proxy", ":"))
tcpProxy := tunnel.NewMultiTCPProxy(map[string]string{
port(tcpAddr): echoTCPListener.Addr().String(),
}, log.NewContext(logger).WithPrefix("TCP proxy", ":"))
c := tunnel.NewClient(&tunnel.ClientConfig{
ServerAddr: s.Addr(),
TLSClientConfig: TLSConfig(cert),
Proxy: proxy,
Logger: log.NewContext(logger).WithPrefix("client", ":"),
Tunnels: tunnels,
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
HTTP: httpProxy.Proxy,
TCP: tcpProxy.Proxy,
}),
Logger: log.NewContext(logger).WithPrefix("client", ":"),
})
if err := c.Start(); err != nil {
panic(err)
}
go c.Start()
// FIXME: replace sleep with client state change watch when ready
time.Sleep(500 * time.Millisecond)
defer c.Stop()
ctx.httpAddr = serverHTTPListener.Addr()
ctx.tcpAddr = serverTCPListener.Addr()
ctx.httpAddr = httpAddr
ctx.tcpAddr = tcpAddr
ctx.payload = randPayload(payloadInitialSize, payloadLen)
m.Run()
@ -166,7 +162,7 @@ func TestProxying(t *testing.T) {
func testHTTP(t *testing.T, payload []byte, repeat uint) {
for repeat > 0 {
url := fmt.Sprintf("http://localhost:%s/some/path", Port(ctx.httpAddr))
url := fmt.Sprintf("http://localhost:%s/some/path", port(ctx.httpAddr))
r, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
panic("Failed to create request")
@ -253,6 +249,19 @@ func randPayload(initialSize, n int) [][]byte {
return payload
}
func freeAddr() net.Addr {
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer l.Close()
return l.Addr()
}
func port(addr net.Addr) string {
return fmt.Sprint(addr.(*net.TCPAddr).Port)
}
func selfSignedCert() (tls.Certificate, id.ID) {
cert, err := tls.LoadX509KeyPair("./test-fixtures/selfsigned.crt", "./test-fixtures/selfsigned.key")
if err != nil {

View file

@ -2,7 +2,6 @@ package integrationtest
import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"math/rand"
@ -10,11 +9,6 @@ import (
"net/http"
)
// Port returns port form TCP address.
func Port(addr net.Addr) string {
return fmt.Sprint(addr.(*net.TCPAddr).Port)
}
// EchoHTTP starts serving HTTP requests on listener l, it accepts connections,
// reads request body and writes is back in response.
func EchoHTTP(l net.Listener) {

30
pool.go
View file

@ -11,21 +11,25 @@ import (
"github.com/mmatczuk/tunnel/id"
)
type onDisconnectListener func(identifier id.ID)
type connPair struct {
conn net.Conn
clientConn *http2.ClientConn
}
type connPool struct {
t *http2.Transport
conns map[string]connPair // key is host:port
mu sync.RWMutex
t *http2.Transport
conns map[string]connPair // key is host:port
listener onDisconnectListener
mu sync.RWMutex
}
func newConnPool(t *http2.Transport) *connPool {
func newConnPool(t *http2.Transport, l onDisconnectListener) *connPool {
return &connPool{
t: t,
conns: make(map[string]connPair),
t: t,
listener: l,
conns: make(map[string]connPair),
}
}
@ -44,10 +48,13 @@ func (p *connPool) MarkDead(c *http2.ClientConn) {
p.mu.Lock()
defer p.mu.Unlock()
for identifier, cp := range p.conns {
for addr, cp := range p.conns {
if cp.clientConn == c {
cp.conn.Close()
delete(p.conns, identifier)
delete(p.conns, addr)
if p.listener != nil {
p.listener(p.addrToIdentifier(addr))
}
return
}
}
@ -84,6 +91,9 @@ func (p *connPool) DeleteConn(identifier id.ID) {
if cp, ok := p.conns[addr]; ok {
cp.conn.Close()
delete(p.conns, addr)
if p.listener != nil {
p.listener(identifier)
}
}
}
@ -94,3 +104,7 @@ func (p *connPool) URL(identifier id.ID) string {
func (p *connPool) addr(identifier id.ID) string {
return fmt.Sprint(identifier.String(), ":443")
}
func (p *connPool) addrToIdentifier(addr string) id.ID {
return id.NewFromString(addr[:len(addr)-4])
}

View file

@ -17,6 +17,7 @@ const (
// ControlMessage headers
const (
ErrorHeader = "Error"
ForwardedHeader = "Forwarded"
)

View file

@ -6,37 +6,56 @@ import (
"sync"
"github.com/mmatczuk/tunnel/id"
"github.com/mmatczuk/tunnel/log"
)
// RegistryItem holds information about hosts and listeners associated with a
// client.
type RegistryItem struct {
Hosts []string
Hosts []*HostAuth
Listeners []net.Listener
}
// HostAuth holds host and authentication info.
type HostAuth struct {
Host string
Auth *Auth
}
type hostInfo struct {
identifier id.ID
auth *Auth
}
// registry manages client tunnels information.
type registry struct {
items map[id.ID]*RegistryItem
hosts map[string]*hostInfo
mu sync.RWMutex
items map[id.ID]*RegistryItem
hosts map[string]*hostInfo
mu sync.RWMutex
logger log.Logger
}
// newRegistry creates new registry.
func newRegistry() *registry {
func newRegistry(logger log.Logger) *registry {
if logger == nil {
logger = log.NewNopLogger()
}
return &registry{
items: make(map[id.ID]*RegistryItem),
hosts: make(map[string]*hostInfo, 0),
items: make(map[id.ID]*RegistryItem),
hosts: make(map[string]*hostInfo, 0),
logger: logger,
}
}
// Subscribe adds new client to registry, this method is idempotent.
var voidRegistryItem = &RegistryItem{}
// Subscribe allows to connect client with a given identifier.
func (r *registry) Subscribe(identifier id.ID) {
r.logger.Log(
"level", 1,
"action", "subscribe",
"identifier", identifier,
)
r.mu.Lock()
defer r.mu.Unlock()
@ -44,13 +63,10 @@ func (r *registry) Subscribe(identifier id.ID) {
return
}
r.items[identifier] = &RegistryItem{
Hosts: make([]string, 0),
Listeners: make([]net.Listener, 0),
}
r.items[identifier] = voidRegistryItem
}
// IsSubscribed returns true if client is subscribed to registry.
// IsSubscribed returns true if client is subscribed.
func (r *registry) IsSubscribed(identifier id.ID) bool {
r.mu.RLock()
defer r.mu.RUnlock()
@ -60,12 +76,10 @@ func (r *registry) IsSubscribed(identifier id.ID) bool {
// Subscriber returns client identifier assigned to given host.
func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
host := trimPort(hostPort)
r.mu.RLock()
defer r.mu.RUnlock()
h, ok := r.hosts[host]
h, ok := r.hosts[trimPort(hostPort)]
if !ok {
return id.ID{}, nil, false
}
@ -73,8 +87,14 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
return h.identifier, h.auth, ok
}
// Unsubscribe removes client from registy and returns it's RegistryItem.
// Unsubscribe removes client from registry and returns it's RegistryItem.
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
r.logger.Log(
"level", 1,
"action", "unsubscribe",
"identifier", identifier,
)
r.mu.Lock()
defer r.mu.Unlock()
@ -82,9 +102,14 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
if !ok {
return nil
}
if i == voidRegistryItem {
return nil
}
for _, h := range i.Hosts {
delete(r.hosts, h)
if i.Hosts != nil {
for _, h := range i.Hosts {
delete(r.hosts, h.Host)
}
}
delete(r.items, identifier)
@ -92,103 +117,71 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
return i
}
// AddHost assigns host to client unless the host is not already taken.
func (r *registry) AddHost(hostPort string, auth *Auth, identifier id.ID) error {
host := trimPort(hostPort)
func (r *registry) set(i *RegistryItem, identifier id.ID) error {
r.logger.Log(
"level", 2,
"action", "set registry item",
"identifier", identifier,
)
r.mu.Lock()
defer r.mu.Unlock()
if auth != nil && auth.User == "" {
return fmt.Errorf("Missing auth user")
}
i, ok := r.items[identifier]
j, ok := r.items[identifier]
if !ok {
return errClientNotSubscribed
}
if _, ok := r.hosts[host]; ok {
return fmt.Errorf("host %q is occupied", host)
}
r.hosts[host] = &hostInfo{
identifier: identifier,
auth: auth,
if j != voidRegistryItem {
return fmt.Errorf("attempt to overwrite registry item")
}
i.Hosts = append(i.Hosts, host)
if i.Hosts != nil {
for _, h := range i.Hosts {
if h.Auth != nil && h.Auth.User == "" {
return fmt.Errorf("missing auth user")
}
if _, ok := r.hosts[trimPort(h.Host)]; ok {
return fmt.Errorf("host %q is occupied", h.Host)
}
}
for _, h := range i.Hosts {
r.hosts[trimPort(h.Host)] = &hostInfo{
identifier: identifier,
auth: h.Auth,
}
}
}
r.items[identifier] = i
return nil
}
// DeleteHost unassigns host from client.
func (r *registry) DeleteHost(hostPort string, identifier id.ID) {
host := trimPort(hostPort)
r.mu.Lock()
defer r.mu.Unlock()
if h, ok := r.hosts[host]; !ok || h.identifier != identifier {
return
}
delete(r.hosts, host)
i := r.items[identifier]
for k, v := range i.Hosts {
if v == host {
i.Hosts = append(i.Hosts[:k], i.Hosts[k+1:]...)
return
}
}
}
// AddListener adds client listener.
func (r *registry) AddListener(l net.Listener, identifier id.ID) error {
if l == nil {
panic("Missing listener")
}
func (r *registry) clear(identifier id.ID) *RegistryItem {
r.logger.Log(
"level", 2,
"action", "clear registry item",
"identifier", identifier,
)
r.mu.Lock()
defer r.mu.Unlock()
i, ok := r.items[identifier]
if !ok {
return errClientNotSubscribed
if !ok || i == voidRegistryItem {
return nil
}
for k, v := range i.Listeners {
if v == l {
return fmt.Errorf("listener already added at %d", k)
if i.Hosts != nil {
for _, h := range i.Hosts {
delete(r.hosts, trimPort(h.Host))
}
}
i.Listeners = append(i.Listeners, l)
r.items[identifier] = voidRegistryItem
return nil
}
// DeleteListener removes listener from client. Listener must be closed to stop
// accepting go routine.
func (r *registry) DeleteListener(l net.Listener, identifier id.ID) {
if l == nil {
panic("Missing listener")
}
r.mu.Lock()
defer r.mu.Unlock()
i, ok := r.items[identifier]
if !ok {
return
}
for k, v := range i.Listeners {
if v == l {
i.Listeners = append(i.Listeners[:k], i.Listeners[k+1:]...)
return
}
}
return i
}
func trimPort(hostPort string) (host string) {

View file

@ -1,230 +0,0 @@
package tunnel
import (
"net"
"reflect"
"strings"
"testing"
"github.com/mmatczuk/tunnel/id"
)
var (
a = id.NewFromString("A")
b = id.NewFromString("B")
)
func TestRegistry_IsSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
if !r.IsSubscribed(a) {
t.Fatal("Client should be subscribed")
}
if r.IsSubscribed(b) {
t.Fatal("Client should not be subscribed")
}
}
func TestRegistry_UnsubscribeOnce(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
if r.Unsubscribe(a) == nil {
t.Fatal("Unsubscribe should return RegistryItem")
}
if r.Unsubscribe(a) != nil {
t.Fatal("Unsubscribe should return nil")
}
}
func TestRegistry_UnsubscribeReturnsHosts(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.AddHost("host1", nil, a)
i := r.Unsubscribe(a)
if !reflect.DeepEqual(i.Hosts, []string{"host0", "host1"}) {
t.Fatal("RegistryItem should contain hosts")
}
}
func TestRegistry_UnsubscribeReturnsListeners(t *testing.T) {
t.Parallel()
l0 := &net.TCPListener{}
l1 := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l0, a)
r.AddListener(l1, a)
i := r.Unsubscribe(a)
if !reflect.DeepEqual(i.Listeners, []net.Listener{l0, l1}) {
t.Fatal("RegistryItem should contain hosts")
}
}
func TestRegistry_AddHostOnlyToSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddHost("host0", nil, a); err != errClientNotSubscribed {
t.Fatal("Adding host should be possible to subscribned clients only")
}
}
func TestRegistry_AddHostAuth(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", &Auth{User: "A", Password: "B"}, a)
_, auth, _ := r.Subscriber("host0")
if !reflect.DeepEqual(auth, &Auth{User: "A", Password: "B"}) {
t.Fatal("Expected auth")
}
}
func TestRegistry_AddHostTrimsPort(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.AddHost("host1:80", nil, a)
tests := []string{
"host0",
"host0:80",
"host0:8080",
"host1",
"host1:80",
"host1:8080",
}
for _, tt := range tests {
identifier, auth, ok := r.Subscriber(tt)
if !ok {
t.Fatal("Subscriber not found")
}
if auth != nil {
t.Fatal("Unexpeted auth")
}
if identifier != a {
t.Fatal("Unexpeted identifier")
}
}
}
func TestRegistry_AddHostOnce(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
tests := []string{
"host0",
"host0:80",
"host0:8080",
}
for _, tt := range tests {
if err := r.AddHost(tt, nil, a); !strings.Contains(err.Error(), "occupied") {
t.Log(tt)
t.Errorf("Adding host %q should fail", tt)
}
}
r.Subscribe(b)
for _, tt := range tests {
if err := r.AddHost(tt, nil, b); !strings.Contains(err.Error(), "occupied") {
t.Log(err)
t.Errorf("Adding host %q should fail", tt)
}
}
}
func TestRegistry_DeleteHost(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.DeleteHost("host0", a)
if _, _, ok := r.Subscriber("host0"); ok {
t.Fatal("Should delete host for a")
}
i := r.Unsubscribe(a)
if len(i.Hosts) != 0 {
t.Fatal("Host was not deleted from item")
}
}
func TestRegistry_DeleteOnlyOwnedHost(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.DeleteHost("host0", b)
if _, _, ok := r.Subscriber("host0"); !ok {
t.Fatal("Should not delete host for b")
}
}
func TestRegistry_AddListenerOnlyToSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddListener(&net.TCPListener{}, a); err != errClientNotSubscribed {
t.Fatal("Adding listener should be possible to subscribned clients only")
}
}
func TestRegistry_AddListenerOnce(t *testing.T) {
t.Parallel()
l := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l, a)
if err := r.AddListener(l, a); err == nil {
t.Fatal("Adding listenr should fail")
}
}
func TestRegistry_DeleteListener(t *testing.T) {
t.Parallel()
l := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l, a)
r.DeleteListener(l, a)
i := r.Unsubscribe(a)
if len(i.Listeners) != 0 {
t.Fatal("Host was not deleted from item")
}
}

269
server.go
View file

@ -1,6 +1,7 @@
package tunnel
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
@ -45,26 +46,28 @@ type Server struct {
func NewServer(config *ServerConfig) (*Server, error) {
listener, err := listener(config)
if err != nil {
return nil, fmt.Errorf("tls listener failed :%s", err)
return nil, fmt.Errorf("tls listener failedAuthorization: %s", err)
}
t := &http2.Transport{}
pool := newConnPool(t)
t.ConnPool = pool
logger := config.Logger
if logger == nil {
logger = log.NewNopLogger()
}
return &Server{
registry: newRegistry(),
config: config,
listener: listener,
connPool: pool,
httpClient: &http.Client{Transport: t},
logger: logger,
}, nil
s := &Server{
registry: newRegistry(logger),
config: config,
listener: listener,
logger: logger,
}
t := &http2.Transport{}
pool := newConnPool(t, s.disconnected)
t.ConnPool = pool
s.connPool = pool
s.httpClient = &http.Client{Transport: t}
return s, nil
}
func listener(config *ServerConfig) (net.Listener, error) {
@ -82,18 +85,40 @@ func listener(config *ServerConfig) (net.Listener, error) {
return tls.Listen("tcp", config.Addr, config.TLSConfig)
}
// Start starts accepting connections form clients and allowed clients listeners.
// For accepting http traffic one must run server as a handler to http server.
// disconnected clears resources used by client, it's invoked by connection pool
// when client goes away.
func (s *Server) disconnected(identifier id.ID) {
s.logger.Log(
"level", 1,
"action", "disconnected",
"identifier", identifier,
)
i := s.registry.clear(identifier)
if i == nil {
return
}
for _, l := range i.Listeners {
s.logger.Log(
"level", 2,
"action", "close listener",
"identifier", identifier,
"addr", l.Addr(),
)
l.Close()
}
}
// Start starts accepting connections form clients. For accepting http traffic
// from end users server must be run as handler on http server.
func (s *Server) Start() {
s.logger.Log(
"level", 1,
"action", "start",
"addr", s.listener.Addr(),
)
go s.listenControl()
}
func (s *Server) listenControl() {
for {
conn, err := s.listener.Accept()
if err != nil {
@ -187,6 +212,11 @@ func (s *Server) handleClient(conn net.Conn) {
goto reject
}
{
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
req.WithContext(ctx)
}
resp, err = s.httpClient.Do(req)
if err != nil {
logger.Log(
@ -196,46 +226,53 @@ func (s *Server) handleClient(conn net.Conn) {
)
goto reject
}
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("Status %s", resp.Status)
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", fmt.Errorf("Status %s", resp.Status),
"err", err,
)
goto reject
}
if resp.ContentLength > 0 {
var done chan struct{}
go func() {
err = json.NewDecoder(&io.LimitedReader{
R: resp.Body,
N: 126976,
}).Decode(&tunnels)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Minute):
err = fmt.Errorf("timeout")
}
if resp.ContentLength == 0 {
err = fmt.Errorf("Tunnels Content-Legth: 0")
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
if err != nil {
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
if err = s.AddTunnels(tunnels, identifier); err != nil {
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil {
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
if len(tunnels) == 0 {
err = fmt.Errorf("No tunnels")
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
if err = s.addTunnels(tunnels, identifier); err != nil {
logger.Log(
"level", 2,
"msg", "handshake failed",
"err", err,
)
goto reject
}
logger.Log(
@ -246,60 +283,89 @@ func (s *Server) handleClient(conn net.Conn) {
return
reject:
s.logger.Log(
logger.Log(
"level", 1,
"action", "rejected",
"addr", conn.RemoteAddr(),
)
s.notifyError(err, identifier)
s.connPool.DeleteConn(identifier)
}
// AddTunnels invokes AddHost or AddListener based on data from proto.Tunnel. If
// a tunnel cannot be added whole batch is reverted.
func (s *Server) AddTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
var (
hosts []string
listeners []net.Listener
err error
)
// notifyError tries to send error to client.
func (s *Server) notifyError(serverError error, identifier id.ID) {
if serverError == nil {
return
}
req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
if err != nil {
s.logger.Log(
"level", 2,
"action", "client error notification failed",
"identifier", identifier,
"err", err,
)
return
}
req.Header.Set(proto.ErrorHeader, serverError.Error())
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
req.WithContext(ctx)
s.httpClient.Do(req)
}
// addTunnels invokes addHost or addListener based on data from proto.Tunnel. If
// a tunnel cannot be added whole batch is reverted.
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
i := &RegistryItem{
Hosts: []*HostAuth{},
Listeners: []net.Listener{},
}
var err error
for name, t := range tunnels {
switch t.Protocol {
case proto.HTTP:
err = s.AddHost(t.Host, NewAuth(t.Auth), identifier)
if err != nil {
goto rollback
}
hosts = append(hosts, t.Host)
i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)})
case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX:
var l net.Listener
l, err = net.Listen(t.Protocol, t.Addr)
if err != nil {
goto rollback
}
listeners = append(listeners, l)
err := s.AddListener(l, identifier)
if err != nil {
goto rollback
}
s.logger.Log(
"level", 2,
"action", "open listener",
"identifier", identifier,
"addr", l.Addr(),
)
i.Listeners = append(i.Listeners, l)
default:
err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol)
goto rollback
}
}
err = s.set(i, identifier)
if err != nil {
goto rollback
}
for _, l := range i.Listeners {
s.listen(l, identifier)
}
return nil
rollback:
for _, h := range hosts {
s.DeleteHost(h, identifier)
}
for _, l := range listeners {
for _, l := range i.Listeners {
l.Close()
s.DeleteListener(l, identifier)
}
return err
@ -312,17 +378,6 @@ func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
return s.registry.Unsubscribe(identifier)
}
// AddListener adds listener to client.
func (s *Server) AddListener(l net.Listener, identifier id.ID) error {
if err := s.registry.AddListener(l, identifier); err != nil {
return err
}
go s.listen(l, identifier)
return nil
}
func (s *Server) listen(l net.Listener, identifier id.ID) {
for {
conn, err := l.Accept()
@ -330,6 +385,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
s.logger.Log(
"level", 2,
"msg", "accept connection failed",
"identifier", identifier,
"err", err,
)
if strings.Contains(err.Error(), "use of closed network connection") {
@ -352,6 +408,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
s.logger.Log(
"level", 2,
"action", "proxy",
"identifier", identifier,
"ctrlMsg", msg,
)
@ -364,6 +421,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
s.logger.Log(
"level", 0,
"msg", "proxy error",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
@ -371,17 +429,22 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
return
}
go transfer(pw, conn, log.NewContext(s.logger).With(
"dir", "user to client",
"dst", identifier,
"src", conn.RemoteAddr(),
))
done := make(chan struct{})
go func() {
transfer(pw, conn, log.NewContext(s.logger).With(
"dir", "user to client",
"dst", identifier,
"src", conn.RemoteAddr(),
))
close(done)
}()
resp, err := s.httpClient.Do(req)
if err != nil {
s.logger.Log(
"level", 0,
"msg", "proxy error",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
@ -394,17 +457,27 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
"dst", conn.RemoteAddr(),
"src", identifier,
))
<-done
}
// ServeHTTP proxies http connection to the client.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resp, err := s.RoundTrip(r)
if err == errUnauthorised {
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if err != nil {
s.logger.Log(
"level", 0,
"action", "round trip failed",
"addr", r.RemoteAddr,
"url", r.URL,
"err", err,
)
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
@ -438,6 +511,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
if auth.User != user || auth.Password != password {
return nil, errUnauthorised
}
r.Header.Del("Authorization")
}
return s.proxyHTTP(identifier, r, msg)
@ -447,6 +521,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
s.logger.Log(
"level", 2,
"action", "proxy",
"identifier", identifier,
"ctrlMsg", msg,
)
@ -466,6 +541,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
s.logger.Log(
"level", 0,
"msg", "proxy error",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
@ -474,6 +550,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
s.logger.Log(
"level", 3,
"action", "transferred",
"identifier", identifier,
"bytes", cw.count,
"dir", "user to client",
"dst", r.Host,
@ -490,6 +567,14 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
return nil, fmt.Errorf("proxy error: %s", err)
}
s.logger.Log(
"level", 2,
"action", "proxy done",
"identifier", identifier,
"ctrlMsg", msg,
"status code", resp.StatusCode,
)
return resp, nil
}

View file

@ -77,7 +77,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
return
}
local, err := net.DialTimeout("tcp", target, DefaultDialTimeout)
local, err := net.DialTimeout("tcp", target, DefaultTimeout)
if err != nil {
p.logger.Log(
"level", 0,
@ -100,6 +100,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
"dst", target,
"src", msg.ForwardedBy,
))
<-done
}
func (p *TCPProxy) localAddrFor(hostPort string) string {
@ -107,7 +108,7 @@ func (p *TCPProxy) localAddrFor(hostPort string) string {
return p.localAddr
}
// try host and port
// try hostPort
if addr := p.localAddrMap[hostPort]; addr != "" {
return addr
}
@ -118,6 +119,11 @@ func (p *TCPProxy) localAddrFor(hostPort string) string {
return addr
}
// try 0.0.0.0:port
if addr := p.localAddrMap[fmt.Sprintf("0.0.0.0:%s", port)]; addr != "" {
return addr
}
// try host
if addr := p.localAddrMap[host]; addr != "" {
return addr