mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
commit
efe21afce5
20 changed files with 1233 additions and 521 deletions
10
TODO.md
10
TODO.md
|
|
@ -1,18 +1,18 @@
|
|||
Release 1.0
|
||||
|
||||
1. cli: cli and file configuration based on ngrok2 https://ngrok.com/docs#config
|
||||
1. docs: README update
|
||||
1. docs: new README
|
||||
|
||||
Backlog
|
||||
|
||||
1. monitoring: client connection state machine
|
||||
1. monitoring: ping https://godoc.org/github.com/hashicorp/yamux#Session.Ping
|
||||
1. monitoring: prometheus.io integration
|
||||
1. proxy: WebSockets
|
||||
1. docs: demo
|
||||
1. proxy: UDP
|
||||
1. proxy: file system
|
||||
1. proxy: host_header modifier
|
||||
1. security: certificate signature checks
|
||||
1. cli: integrated certificate generation
|
||||
1. monitoring: prometheus.io integration
|
||||
|
||||
|
||||
Notes for README
|
||||
|
||||
|
|
|
|||
96
client.go
96
client.go
|
|
@ -16,9 +16,8 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
// DefaultDialTimeout specifies how long client should wait for tunnel
|
||||
// server or local service connection.
|
||||
DefaultDialTimeout = 10 * time.Second
|
||||
// DefaultTimeout specifies general purpose timeout.
|
||||
DefaultTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// ClientConfig is configuration of the Client.
|
||||
|
|
@ -51,6 +50,7 @@ type Client struct {
|
|||
conn net.Conn
|
||||
connMu sync.Mutex
|
||||
httpServer *http2.Server
|
||||
serverErr error
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
|
|
@ -63,6 +63,9 @@ func NewClient(config *ClientConfig) *Client {
|
|||
if config.TLSClientConfig == nil {
|
||||
panic("Missing TLSClientConfig")
|
||||
}
|
||||
if config.Tunnels == nil || len(config.Tunnels) == 0 {
|
||||
panic("Missing Tunnels")
|
||||
}
|
||||
if config.Proxy == nil {
|
||||
panic("Missing Proxy")
|
||||
}
|
||||
|
|
@ -85,29 +88,53 @@ func NewClient(config *ClientConfig) *Client {
|
|||
// error, otherwise it spawns a new goroutine with http/2 server handling
|
||||
// ControlMessages.
|
||||
func (c *Client) Start() error {
|
||||
c.connMu.Lock()
|
||||
defer c.connMu.Unlock()
|
||||
|
||||
c.logger.Log(
|
||||
"level", 1,
|
||||
"action", "start",
|
||||
)
|
||||
|
||||
for {
|
||||
conn, err := c.connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: http.HandlerFunc(c.serveHTTP),
|
||||
})
|
||||
|
||||
c.logger.Log(
|
||||
"level", 1,
|
||||
"action", "disconnected",
|
||||
)
|
||||
|
||||
c.connMu.Lock()
|
||||
err = c.serverErr
|
||||
c.conn = nil
|
||||
c.serverErr = nil
|
||||
c.connMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("server error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connect() (net.Conn, error) {
|
||||
c.connMu.Lock()
|
||||
defer c.connMu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return fmt.Errorf("already connected")
|
||||
return nil, fmt.Errorf("already connected")
|
||||
}
|
||||
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to server: %s", err)
|
||||
return nil, fmt.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
c.conn = conn
|
||||
|
||||
go c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Handler: http.HandlerFunc(c.serveHTTP),
|
||||
})
|
||||
|
||||
return nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) dial() (net.Conn, error) {
|
||||
|
|
@ -129,7 +156,7 @@ func (c *Client) dial() (net.Conn, error) {
|
|||
conn, err = c.config.DialTLS(network, addr, tlsConfig)
|
||||
} else {
|
||||
conn, err = tls.DialWithDialer(
|
||||
&net.Dialer{Timeout: DefaultDialTimeout},
|
||||
&net.Dialer{Timeout: DefaultTimeout},
|
||||
network, addr, tlsConfig,
|
||||
)
|
||||
}
|
||||
|
|
@ -181,7 +208,11 @@ func (c *Client) dial() (net.Conn, error) {
|
|||
|
||||
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodConnect {
|
||||
c.handleHandshake(w, r)
|
||||
if r.Header.Get(proto.ErrorHeader) != "" {
|
||||
c.handleHandshakeError(w, r)
|
||||
} else {
|
||||
c.handleHandshake(w, r)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -218,6 +249,21 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
)
|
||||
}
|
||||
|
||||
func (c *Client) handleHandshakeError(w http.ResponseWriter, r *http.Request) {
|
||||
err := fmt.Errorf(r.Header.Get(proto.ErrorHeader))
|
||||
|
||||
c.logger.Log(
|
||||
"level", 1,
|
||||
"action", "handshake error",
|
||||
"addr", r.RemoteAddr,
|
||||
"err", err,
|
||||
)
|
||||
|
||||
c.connMu.Lock()
|
||||
c.serverErr = err
|
||||
c.connMu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
|
||||
c.logger.Log(
|
||||
"level", 1,
|
||||
|
|
@ -227,18 +273,16 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if c.config.Tunnels != nil {
|
||||
b, err := json.Marshal(c.config.Tunnels)
|
||||
if err != nil {
|
||||
c.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
w.Write(b)
|
||||
b, err := json.Marshal(c.config.Tunnels)
|
||||
if err != nil {
|
||||
c.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
// Stop closes the connection between client and server. After stopping client
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/mmatczuk/tunnel/mock"
|
||||
"github.com/mmatczuk/tunnel/proto"
|
||||
)
|
||||
|
||||
func TestClient_Dial(t *testing.T) {
|
||||
|
|
@ -23,7 +24,8 @@ func TestClient_Dial(t *testing.T) {
|
|||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Proxy: Proxy(ProxyFuncs{}),
|
||||
Tunnels: map[string]*proto.Tunnel{"test": {}},
|
||||
Proxy: Proxy(ProxyFuncs{}),
|
||||
})
|
||||
|
||||
conn, err := c.dial()
|
||||
|
|
@ -57,6 +59,7 @@ func TestClient_DialBackoff(t *testing.T) {
|
|||
TLSClientConfig: &tls.Config{},
|
||||
DialTLS: d,
|
||||
Backoff: b,
|
||||
Tunnels: map[string]*proto.Tunnel{"test": {}},
|
||||
Proxy: Proxy(ProxyFuncs{}),
|
||||
})
|
||||
|
||||
|
|
|
|||
37
cmd/cmd/log.go
Normal file
37
cmd/cmd/log.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/mmatczuk/tunnel/log"
|
||||
)
|
||||
|
||||
// NewLogger returns logfmt based logger, printing messages up to log level
|
||||
// logLevel.
|
||||
func NewLogger(to string, level int) (log.Logger, error) {
|
||||
var w io.Writer
|
||||
|
||||
switch to {
|
||||
case "none":
|
||||
return log.NewNopLogger(), nil
|
||||
case "stdout":
|
||||
w = os.Stdout
|
||||
case "stderr":
|
||||
w = os.Stderr
|
||||
default:
|
||||
f, err := os.Create(to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w = f
|
||||
}
|
||||
|
||||
var logger kitlog.Logger
|
||||
logger = kitlog.NewJSONLogger(kitlog.NewSyncWriter(w))
|
||||
logger = kitlog.NewContext(logger).WithPrefix("time", kitlog.Timestamp(time.Now))
|
||||
logger = log.NewFilterLogger(logger, level)
|
||||
return logger, nil
|
||||
}
|
||||
152
cmd/tunnel/config.go
Normal file
152
cmd/tunnel/config.go
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/mmatczuk/tunnel/proto"
|
||||
)
|
||||
|
||||
type BackoffConfig struct {
|
||||
InitialInterval time.Duration `yaml:"interval,omitempty"`
|
||||
Multiplier float64 `yaml:"multiplier,omitempty"`
|
||||
MaxInterval time.Duration `yaml:"max_interval,omitempty"`
|
||||
MaxElapsedTime time.Duration `yaml:"max_time,omitempty"`
|
||||
}
|
||||
|
||||
type TunnelConfig struct {
|
||||
Protocol string `yaml:"proto,omitempty"`
|
||||
Addr string `yaml:"addr,omitempty"`
|
||||
Auth string `yaml:"auth,omitempty"`
|
||||
Host string `yaml:"host,omitempty"`
|
||||
RemoteAddr string `yaml:"remote_addr,omitempty"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ServerAddr string `yaml:"server_addr,omitempty"`
|
||||
InsecureSkipVerify bool `yaml:"insecure_skip_verify,omitempty"`
|
||||
TLSCrt string `yaml:"tls_crt,omitempty"`
|
||||
TLSKey string `yaml:"tls_key,omitempty"`
|
||||
Backoff *BackoffConfig `yaml:"backoff,omitempty"`
|
||||
Tunnels map[string]*TunnelConfig `yaml:"tunnels,omitempty"`
|
||||
}
|
||||
|
||||
var defaultBackoffConfig = BackoffConfig{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
Multiplier: 1.5,
|
||||
MaxInterval: 60 * time.Second,
|
||||
MaxElapsedTime: 15 * time.Minute,
|
||||
}
|
||||
|
||||
func loadConfiguration(path string) (*Config, error) {
|
||||
configBuf, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %q: %s", path, err)
|
||||
}
|
||||
|
||||
// deserialize/parse the config
|
||||
var config Config
|
||||
if err = yaml.Unmarshal(configBuf, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse file %q: %s", path, err)
|
||||
}
|
||||
|
||||
// set default values
|
||||
if config.TLSCrt == "" {
|
||||
config.TLSCrt = filepath.Join(filepath.Dir(path), "client.crt")
|
||||
}
|
||||
if config.TLSKey == "" {
|
||||
config.TLSKey = filepath.Join(filepath.Dir(path), "client.key")
|
||||
}
|
||||
|
||||
if config.Backoff == nil {
|
||||
config.Backoff = &defaultBackoffConfig
|
||||
} else {
|
||||
if config.Backoff.InitialInterval == 0 {
|
||||
config.Backoff.InitialInterval = defaultBackoffConfig.InitialInterval
|
||||
}
|
||||
if config.Backoff.Multiplier == 0 {
|
||||
config.Backoff.Multiplier = defaultBackoffConfig.Multiplier
|
||||
}
|
||||
if config.Backoff.MaxInterval == 0 {
|
||||
config.Backoff.MaxInterval = defaultBackoffConfig.MaxInterval
|
||||
}
|
||||
if config.Backoff.MaxElapsedTime == 0 {
|
||||
config.Backoff.MaxElapsedTime = defaultBackoffConfig.MaxElapsedTime
|
||||
}
|
||||
}
|
||||
|
||||
// validate and normalize configuration
|
||||
if config.ServerAddr == "" {
|
||||
return nil, fmt.Errorf("server_addr: missing")
|
||||
}
|
||||
|
||||
if config.ServerAddr, err = normalizeAddress(config.ServerAddr); err != nil {
|
||||
return nil, fmt.Errorf("server_addr: %s", err)
|
||||
}
|
||||
|
||||
for name, t := range config.Tunnels {
|
||||
switch t.Protocol {
|
||||
case proto.HTTP:
|
||||
if err := validateHTTP(t); err != nil {
|
||||
return nil, fmt.Errorf("%s %s", name, err)
|
||||
}
|
||||
case proto.TCP, proto.TCP4, proto.TCP6:
|
||||
if err := validateTCP(t); err != nil {
|
||||
return nil, fmt.Errorf("%s %s", name, err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("%s invalid protocol %q", name, t.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func validateHTTP(t *TunnelConfig) error {
|
||||
var err error
|
||||
if t.Host == "" {
|
||||
return fmt.Errorf("host: missing")
|
||||
}
|
||||
if t.Addr == "" {
|
||||
return fmt.Errorf("addr: missing")
|
||||
}
|
||||
if t.Addr, err = normalizeURL(t.Addr); err != nil {
|
||||
return fmt.Errorf("addr: %s", err)
|
||||
}
|
||||
|
||||
// unexpected
|
||||
|
||||
if t.RemoteAddr != "" {
|
||||
return fmt.Errorf("remote_addr: unexpected")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateTCP(t *TunnelConfig) error {
|
||||
var err error
|
||||
if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil {
|
||||
return fmt.Errorf("remote_addr: %s", err)
|
||||
}
|
||||
if t.Addr == "" {
|
||||
return fmt.Errorf("addr: missing")
|
||||
}
|
||||
if t.Addr, err = normalizeAddress(t.Addr); err != nil {
|
||||
return fmt.Errorf("addr: %s", err)
|
||||
}
|
||||
|
||||
// unexpected
|
||||
|
||||
if t.Host != "" {
|
||||
return fmt.Errorf("host: unexpected")
|
||||
}
|
||||
if t.Auth != "" {
|
||||
return fmt.Errorf("auth: unexpected")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
52
cmd/tunnel/normalize.go
Normal file
52
cmd/tunnel/normalize.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func normalizeAddress(addr string) (string, error) {
|
||||
// normalize port to addr
|
||||
if _, err := strconv.Atoi(addr); err == nil {
|
||||
addr = ":" + addr
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s", host, port), nil
|
||||
}
|
||||
|
||||
func normalizeURL(rawurl string) (string, error) {
|
||||
// check scheme
|
||||
s := strings.SplitN(rawurl, "://", 2)
|
||||
if len(s) > 1 {
|
||||
switch s[0] {
|
||||
case "http", "https":
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported url schema, choose 'http' or 'https'")
|
||||
}
|
||||
} else {
|
||||
rawurl = fmt.Sprint("http://", rawurl)
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if u.Path != "" && !strings.HasSuffix(u.Path, "/") {
|
||||
return "", fmt.Errorf("url must end with '/'")
|
||||
}
|
||||
|
||||
return rawurl, nil
|
||||
}
|
||||
111
cmd/tunnel/normalize_test.go
Normal file
111
cmd/tunnel/normalize_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
expected string
|
||||
error string
|
||||
}{
|
||||
{
|
||||
addr: "22",
|
||||
expected: "127.0.0.1:22",
|
||||
},
|
||||
{
|
||||
addr: ":22",
|
||||
expected: "127.0.0.1:22",
|
||||
},
|
||||
{
|
||||
addr: "0.0.0.0:22",
|
||||
expected: "0.0.0.0:22",
|
||||
},
|
||||
{
|
||||
addr: "0.0.0.0",
|
||||
error: "missing port",
|
||||
},
|
||||
{
|
||||
addr: "",
|
||||
error: "missing port",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := normalizeAddress(tt.addr)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("[%d] expected %q got %q err: %s", i, tt.expected, actual, err)
|
||||
}
|
||||
if tt.error != "" && err == nil {
|
||||
t.Errorf("[%d] expected error", i)
|
||||
}
|
||||
if err != nil && (tt.error == "" || !strings.Contains(err.Error(), tt.error)) {
|
||||
t.Errorf("[%d] expected error contains %q, got %q", i, tt.error, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
rawurl string
|
||||
expected string
|
||||
error string
|
||||
}{
|
||||
{
|
||||
rawurl: "localhost",
|
||||
expected: "http://localhost",
|
||||
},
|
||||
{
|
||||
rawurl: "localhost:80",
|
||||
expected: "http://localhost:80",
|
||||
},
|
||||
{
|
||||
rawurl: "localhost:80/path/",
|
||||
expected: "http://localhost:80/path/",
|
||||
},
|
||||
{
|
||||
rawurl: "localhost:80/path",
|
||||
error: "/",
|
||||
},
|
||||
{
|
||||
rawurl: "https://localhost",
|
||||
expected: "https://localhost",
|
||||
},
|
||||
{
|
||||
rawurl: "https://localhost:443",
|
||||
expected: "https://localhost:443",
|
||||
},
|
||||
{
|
||||
rawurl: "https://localhost:443/path/",
|
||||
expected: "https://localhost:443/path/",
|
||||
},
|
||||
{
|
||||
rawurl: "https://localhost:443/path",
|
||||
error: "/",
|
||||
},
|
||||
{
|
||||
rawurl: "ftp://localhost",
|
||||
error: "unsupported url schema",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
actual, err := normalizeURL(tt.rawurl)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("[%d] expected %q got %q, err: %s", i, tt.expected, actual, err)
|
||||
}
|
||||
if tt.error != "" && err == nil {
|
||||
t.Errorf("[%d] expected error", i)
|
||||
}
|
||||
if err != nil && (tt.error == "" || !strings.Contains(err.Error(), tt.error)) {
|
||||
t.Errorf("[%d] expected error contains %q, got %q", i, tt.error, err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
108
cmd/tunnel/options.go
Normal file
108
cmd/tunnel/options.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const usage1 string = `Usage: tunnel [OPTIONS] <command> [command args] [...]
|
||||
options:
|
||||
`
|
||||
|
||||
const usage2 string = `
|
||||
Commands:
|
||||
tunnel start [tunnel] [...] Start tunnels by name from config file
|
||||
tunnel start-all Start all tunnels defined in config file
|
||||
tunnel list List tunnel names from config file
|
||||
|
||||
Examples:
|
||||
tunnel start www ssh
|
||||
tunnel -config=config.yaml -log=stdout -log-level 2 start ssh
|
||||
tunnel start-all
|
||||
|
||||
config.yaml:
|
||||
server_addr: SERVER_IP:4443
|
||||
insecure_skip_verify: true
|
||||
tunnels:
|
||||
www:
|
||||
proto: http
|
||||
addr: http://IP:8080/ui/
|
||||
auth: user:password
|
||||
host: ui.mytunnel.com
|
||||
ssh:
|
||||
proto: tcp
|
||||
addr: IP:22
|
||||
remote_addr: 0.0.0.0:2222
|
||||
`
|
||||
|
||||
func init() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage1)
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, usage2)
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
debug bool
|
||||
config string
|
||||
logTo string
|
||||
logLevel int
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
func parseArgs() (*options, error) {
|
||||
debug := flag.Bool("debug", false, "Starts gops agent")
|
||||
config := flag.String("config", filepath.Join(defaultPath(), "config.yaml"), "Path to tunnel configuration file")
|
||||
logTo := flag.String("log", "stdout", "Write log messages to this file, file name or 'stdout', 'stderr', 'none'")
|
||||
logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3")
|
||||
flag.Parse()
|
||||
|
||||
opts := &options{
|
||||
debug: *debug,
|
||||
config: *config,
|
||||
logTo: *logTo,
|
||||
logLevel: *logLevel,
|
||||
command: flag.Arg(0),
|
||||
}
|
||||
|
||||
switch opts.command {
|
||||
case "list":
|
||||
opts.args = flag.Args()[1:]
|
||||
if len(opts.args) > 0 {
|
||||
return nil, fmt.Errorf("list takes no arguments")
|
||||
}
|
||||
case "start":
|
||||
opts.args = flag.Args()[1:]
|
||||
if len(opts.args) == 0 {
|
||||
return nil, fmt.Errorf("you must specify at least one tunnel to start")
|
||||
}
|
||||
case "start-all":
|
||||
opts.args = flag.Args()[1:]
|
||||
if len(opts.args) > 0 {
|
||||
return nil, fmt.Errorf("start-all takes no arguments")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("expected command")
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func defaultPath() string {
|
||||
// user.Current() does not work on linux when cross compiling because
|
||||
// it requires CGO; use os.Getenv("HOME") hack until we compile natively
|
||||
var dir string
|
||||
|
||||
if user, err := user.Current(); err == nil {
|
||||
dir = user.HomeDir
|
||||
} else {
|
||||
dir = os.Getenv("HOME")
|
||||
}
|
||||
|
||||
return filepath.Join(dir, ".tunnel")
|
||||
}
|
||||
153
cmd/tunnel/tunnel.go
Normal file
153
cmd/tunnel/tunnel.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/cenkalti/backoff"
|
||||
"github.com/google/gops/agent"
|
||||
"github.com/mmatczuk/tunnel"
|
||||
"github.com/mmatczuk/tunnel/cmd/cmd"
|
||||
"github.com/mmatczuk/tunnel/log"
|
||||
"github.com/mmatczuk/tunnel/proto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
opts, err := parseArgs()
|
||||
if err != nil {
|
||||
fatal(err.Error())
|
||||
}
|
||||
|
||||
if opts.debug {
|
||||
if err := agent.Listen(nil); err != nil {
|
||||
fatal("gops agent failed to start: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger, err := cmd.NewLogger(opts.logTo, opts.logLevel)
|
||||
if err != nil {
|
||||
fatal("failed to init logger: %s", err)
|
||||
}
|
||||
|
||||
// read configuration file
|
||||
config, err := loadConfiguration(opts.config)
|
||||
if err != nil {
|
||||
fatal("configuration error: %s", err)
|
||||
}
|
||||
|
||||
switch opts.command {
|
||||
case "list":
|
||||
var names []string
|
||||
for n := range config.Tunnels {
|
||||
names = append(names, n)
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
|
||||
for _, n := range names {
|
||||
fmt.Println(n)
|
||||
}
|
||||
|
||||
return
|
||||
case "start":
|
||||
tunnels := make(map[string]*TunnelConfig)
|
||||
for _, arg := range opts.args {
|
||||
t, ok := config.Tunnels[arg]
|
||||
if !ok {
|
||||
fatal("no such tunnel %q", arg)
|
||||
}
|
||||
tunnels[arg] = t
|
||||
}
|
||||
config.Tunnels = tunnels
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey)
|
||||
if err != nil {
|
||||
fatal("failed to load certificate: %s", err)
|
||||
}
|
||||
|
||||
b, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
logger.Log("config", string(b))
|
||||
|
||||
client := tunnel.NewClient(&tunnel.ClientConfig{
|
||||
ServerAddr: config.ServerAddr,
|
||||
TLSClientConfig: tlsConfig(cert, config),
|
||||
Backoff: expBackoff(config.Backoff),
|
||||
Tunnels: tunnels(config.Tunnels),
|
||||
Proxy: proxy(config.Tunnels, logger),
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
if err := client.Start(); err != nil {
|
||||
fatal("%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func tlsConfig(cert tls.Certificate, config *Config) *tls.Config {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
}
|
||||
|
||||
func expBackoff(config *BackoffConfig) *backoff.ExponentialBackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.InitialInterval = config.InitialInterval
|
||||
b.Multiplier = config.Multiplier
|
||||
b.MaxInterval = config.MaxInterval
|
||||
b.MaxElapsedTime = config.MaxElapsedTime
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func tunnels(m map[string]*TunnelConfig) map[string]*proto.Tunnel {
|
||||
p := make(map[string]*proto.Tunnel)
|
||||
|
||||
for name, t := range m {
|
||||
p[name] = &proto.Tunnel{
|
||||
Protocol: t.Protocol,
|
||||
Host: t.Host,
|
||||
Auth: t.Auth,
|
||||
Addr: t.RemoteAddr,
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func proxy(m map[string]*TunnelConfig, logger log.Logger) tunnel.ProxyFunc {
|
||||
httpURL := make(map[string]*url.URL)
|
||||
tcpAddr := make(map[string]string)
|
||||
|
||||
for _, t := range m {
|
||||
switch t.Protocol {
|
||||
case proto.HTTP:
|
||||
u, err := url.Parse(t.Addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
httpURL[t.Host] = u
|
||||
case proto.TCP, proto.TCP4, proto.TCP6:
|
||||
tcpAddr[t.RemoteAddr] = t.Addr
|
||||
}
|
||||
}
|
||||
|
||||
return tunnel.Proxy(tunnel.ProxyFuncs{
|
||||
HTTP: tunnel.NewMultiHTTPProxy(httpURL, logger).Proxy,
|
||||
TCP: tunnel.NewMultiTCPProxy(tcpAddr, logger).Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
func fatal(format string, a ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, format, a...)
|
||||
fmt.Fprint(os.Stderr, "\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
56
cmd/tunneld/options.go
Normal file
56
cmd/tunneld/options.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const usage1 string = `Usage: tunneld [OPTIONS]
|
||||
options:
|
||||
`
|
||||
|
||||
func init() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage1)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
}
|
||||
|
||||
// options specify arguments read command line arguments.
|
||||
type options struct {
|
||||
debug bool
|
||||
httpAddr string
|
||||
httpsAddr string
|
||||
tunnelAddr string
|
||||
tlsCrt string
|
||||
tlsKey string
|
||||
clients string
|
||||
logTo string
|
||||
logLevel int
|
||||
}
|
||||
|
||||
func parseArgs() *options {
|
||||
debug := flag.Bool("debug", false, "Starts gops agent")
|
||||
httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable")
|
||||
httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable")
|
||||
tunnelAddr := flag.String("tunnelAddr", ":4443", "Public address listening for tunnel client")
|
||||
tlsCrt := flag.String("tlsCrt", "", "Path to a TLS certificate file")
|
||||
tlsKey := flag.String("tlsKey", "", "Path to a TLS key file")
|
||||
clients := flag.String("clients", "", "Comma-separated list of tunnel client ids")
|
||||
logTo := flag.String("log", "stdout", "Write log messages to this file, file name or 'stdout', 'stderr', 'none'")
|
||||
logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3")
|
||||
flag.Parse()
|
||||
|
||||
return &options{
|
||||
debug: *debug,
|
||||
httpAddr: *httpAddr,
|
||||
httpsAddr: *httpsAddr,
|
||||
tunnelAddr: *tunnelAddr,
|
||||
tlsCrt: *tlsCrt,
|
||||
tlsKey: *tlsKey,
|
||||
clients: *clients,
|
||||
logTo: *logTo,
|
||||
logLevel: *logLevel,
|
||||
}
|
||||
}
|
||||
114
cmd/tunneld/tunneld.go
Normal file
114
cmd/tunneld/tunneld.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/google/gops/agent"
|
||||
"github.com/mmatczuk/tunnel"
|
||||
"github.com/mmatczuk/tunnel/cmd/cmd"
|
||||
"github.com/mmatczuk/tunnel/id"
|
||||
)
|
||||
|
||||
func main() {
|
||||
opts := parseArgs()
|
||||
|
||||
if opts.debug {
|
||||
if err := agent.Listen(nil); err != nil {
|
||||
fatal("gops agent failed to start: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger, err := cmd.NewLogger(opts.logTo, opts.logLevel)
|
||||
if err != nil {
|
||||
fatal("failed to init logger: %s", err)
|
||||
}
|
||||
|
||||
// load certs
|
||||
cert, err := tls.LoadX509KeyPair(opts.tlsCrt, opts.tlsKey)
|
||||
if err != nil {
|
||||
fatal("failed to load certificate: %s", err)
|
||||
}
|
||||
|
||||
// setup server
|
||||
server, err := tunnel.NewServer(&tunnel.ServerConfig{
|
||||
Addr: opts.tunnelAddr,
|
||||
TLSConfig: tlsConfig(cert),
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
fatal("failed to create server: %s", err)
|
||||
}
|
||||
|
||||
if opts.clients == "" {
|
||||
logger.Log(
|
||||
"level", 0,
|
||||
"msg", "No clients",
|
||||
)
|
||||
} else {
|
||||
for _, c := range strings.Split(opts.clients, ",") {
|
||||
if c == "" {
|
||||
fatal("empty client id")
|
||||
}
|
||||
identifier := id.ID{}
|
||||
err := identifier.UnmarshalText([]byte(c))
|
||||
if err != nil {
|
||||
fatal("invalid identifier %q: %s", c, err)
|
||||
}
|
||||
server.Subscribe(identifier)
|
||||
}
|
||||
}
|
||||
|
||||
// start HTTP
|
||||
if opts.httpAddr != "" {
|
||||
go func() {
|
||||
logger.Log(
|
||||
"level", 1,
|
||||
"action", "start http",
|
||||
"addr", opts.httpAddr,
|
||||
)
|
||||
err := http.ListenAndServe(opts.httpAddr, server)
|
||||
if err != nil {
|
||||
fatal("failed to start HTTP: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// start HTTPS
|
||||
if opts.httpsAddr != "" {
|
||||
go func() {
|
||||
logger.Log(
|
||||
"level", 1,
|
||||
"action", "start https",
|
||||
"addr", opts.httpsAddr,
|
||||
)
|
||||
err := http.ListenAndServeTLS(opts.httpsAddr, opts.tlsCrt, opts.tlsKey, server)
|
||||
if err != nil {
|
||||
fatal("failed to start HTTPS: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
server.Start()
|
||||
}
|
||||
|
||||
func tlsConfig(cert tls.Certificate) *tls.Config {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
SessionTicketsDisabled: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
|
||||
PreferServerCipherSuites: true,
|
||||
NextProtos: []string{"h2"},
|
||||
}
|
||||
}
|
||||
|
||||
func fatal(format string, a ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, format, a...)
|
||||
fmt.Fprint(os.Stderr, "\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
28
httpproxy.go
28
httpproxy.go
|
|
@ -9,7 +9,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"path"
|
||||
|
||||
"github.com/mmatczuk/tunnel/log"
|
||||
"github.com/mmatczuk/tunnel/proto"
|
||||
|
|
@ -101,6 +101,8 @@ func (p *HTTPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessag
|
|||
// is correctly routed based on localURL and localURLMap. If no URL can be found
|
||||
// the request is canceled.
|
||||
func (p *HTTPProxy) Director(req *http.Request) {
|
||||
orig := *req.URL
|
||||
|
||||
target := p.localURLFor(req.URL)
|
||||
if target == nil {
|
||||
p.logger.Log(
|
||||
|
|
@ -129,18 +131,26 @@ func (p *HTTPProxy) Director(req *http.Request) {
|
|||
// explicitly disable User-Agent so it's not set to default value
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
req.Host = req.URL.Host
|
||||
|
||||
p.logger.Log(
|
||||
"level", 2,
|
||||
"action", "url rewrite",
|
||||
"from", &orig,
|
||||
"to", req.URL,
|
||||
)
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
if a == "" || a == "/" {
|
||||
return b
|
||||
}
|
||||
return a + b
|
||||
if b == "" || b == "/" {
|
||||
return a
|
||||
}
|
||||
|
||||
return path.Join(a, b)
|
||||
}
|
||||
|
||||
func (p *HTTPProxy) localURLFor(u *url.URL) *url.URL {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/mmatczuk/tunnel"
|
||||
"github.com/mmatczuk/tunnel/id"
|
||||
"github.com/mmatczuk/tunnel/log"
|
||||
"github.com/mmatczuk/tunnel/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -37,13 +38,6 @@ var ctx testContext
|
|||
func TestMain(m *testing.M) {
|
||||
logger := log.NewFilterLogger(log.NewStdLogger(), 1)
|
||||
|
||||
// prepare server TCP listener
|
||||
serverTCPListener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer serverTCPListener.Close()
|
||||
|
||||
// prepare tunnel server
|
||||
cert, identifier := selfSignedCert()
|
||||
s, err := tunnel.NewServer(&tunnel.ServerConfig{
|
||||
|
|
@ -54,30 +48,19 @@ func TestMain(m *testing.M) {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.Subscribe(identifier)
|
||||
|
||||
auth := &tunnel.Auth{
|
||||
User: "user",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
if err := s.AddHost("localhost", auth, identifier); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := s.AddListener(serverTCPListener, identifier); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.Start()
|
||||
go s.Start()
|
||||
defer s.Stop()
|
||||
|
||||
// run server HTTP interface
|
||||
serverHTTPListener, err := net.Listen("tcp", ":0")
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer serverHTTPListener.Close()
|
||||
go http.Serve(serverHTTPListener, s)
|
||||
defer l.Close()
|
||||
go http.Serve(l, s)
|
||||
|
||||
httpAddr := l.Addr()
|
||||
|
||||
// prepare local TCP echo service
|
||||
echoTCPListener, err := net.Listen("tcp", ":0")
|
||||
|
|
@ -95,37 +78,50 @@ func TestMain(m *testing.M) {
|
|||
defer echoHTTPListener.Close()
|
||||
go EchoHTTP(echoHTTPListener)
|
||||
|
||||
// prepare proxy
|
||||
httpproxy := tunnel.NewMultiHTTPProxy(map[string]*url.URL{
|
||||
"localhost:" + Port(serverHTTPListener.Addr()): {
|
||||
Scheme: "http",
|
||||
Host: echoHTTPListener.Addr().String(),
|
||||
},
|
||||
}, log.NewNopLogger())
|
||||
|
||||
tcpproxy := tunnel.NewMultiTCPProxy(map[string]string{
|
||||
Port(serverTCPListener.Addr()): echoTCPListener.Addr().String(),
|
||||
}, log.NewNopLogger())
|
||||
|
||||
proxy := tunnel.Proxy(tunnel.ProxyFuncs{
|
||||
HTTP: httpproxy.Proxy,
|
||||
TCP: tcpproxy.Proxy,
|
||||
})
|
||||
// allocate free port
|
||||
tcpAddr := freeAddr()
|
||||
|
||||
// prepare tunnel client
|
||||
tunnels := map[string]*proto.Tunnel{
|
||||
"http": {
|
||||
Protocol: proto.HTTP,
|
||||
Host: "localhost",
|
||||
Auth: "user:password",
|
||||
},
|
||||
"tcp": {
|
||||
Protocol: proto.TCP,
|
||||
Addr: tcpAddr.String(),
|
||||
},
|
||||
}
|
||||
|
||||
httpProxy := tunnel.NewMultiHTTPProxy(map[string]*url.URL{
|
||||
"localhost:" + port(httpAddr): {
|
||||
Scheme: "http",
|
||||
Host: "127.0.0.1:" + port(echoHTTPListener.Addr()),
|
||||
},
|
||||
}, log.NewContext(logger).WithPrefix("HTTP proxy", ":"))
|
||||
|
||||
tcpProxy := tunnel.NewMultiTCPProxy(map[string]string{
|
||||
port(tcpAddr): echoTCPListener.Addr().String(),
|
||||
}, log.NewContext(logger).WithPrefix("TCP proxy", ":"))
|
||||
|
||||
c := tunnel.NewClient(&tunnel.ClientConfig{
|
||||
ServerAddr: s.Addr(),
|
||||
TLSClientConfig: TLSConfig(cert),
|
||||
Proxy: proxy,
|
||||
Logger: log.NewContext(logger).WithPrefix("client", ":"),
|
||||
Tunnels: tunnels,
|
||||
Proxy: tunnel.Proxy(tunnel.ProxyFuncs{
|
||||
HTTP: httpProxy.Proxy,
|
||||
TCP: tcpProxy.Proxy,
|
||||
}),
|
||||
Logger: log.NewContext(logger).WithPrefix("client", ":"),
|
||||
})
|
||||
if err := c.Start(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go c.Start()
|
||||
// FIXME: replace sleep with client state change watch when ready
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
defer c.Stop()
|
||||
|
||||
ctx.httpAddr = serverHTTPListener.Addr()
|
||||
ctx.tcpAddr = serverTCPListener.Addr()
|
||||
ctx.httpAddr = httpAddr
|
||||
ctx.tcpAddr = tcpAddr
|
||||
ctx.payload = randPayload(payloadInitialSize, payloadLen)
|
||||
|
||||
m.Run()
|
||||
|
|
@ -166,7 +162,7 @@ func TestProxying(t *testing.T) {
|
|||
|
||||
func testHTTP(t *testing.T, payload []byte, repeat uint) {
|
||||
for repeat > 0 {
|
||||
url := fmt.Sprintf("http://localhost:%s/some/path", Port(ctx.httpAddr))
|
||||
url := fmt.Sprintf("http://localhost:%s/some/path", port(ctx.httpAddr))
|
||||
r, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
panic("Failed to create request")
|
||||
|
|
@ -253,6 +249,19 @@ func randPayload(initialSize, n int) [][]byte {
|
|||
return payload
|
||||
}
|
||||
|
||||
func freeAddr() net.Addr {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr()
|
||||
}
|
||||
|
||||
func port(addr net.Addr) string {
|
||||
return fmt.Sprint(addr.(*net.TCPAddr).Port)
|
||||
}
|
||||
|
||||
func selfSignedCert() (tls.Certificate, id.ID) {
|
||||
cert, err := tls.LoadX509KeyPair("./test-fixtures/selfsigned.crt", "./test-fixtures/selfsigned.key")
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package integrationtest
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
|
|
@ -10,11 +9,6 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
// Port returns port form TCP address.
|
||||
func Port(addr net.Addr) string {
|
||||
return fmt.Sprint(addr.(*net.TCPAddr).Port)
|
||||
}
|
||||
|
||||
// EchoHTTP starts serving HTTP requests on listener l, it accepts connections,
|
||||
// reads request body and writes is back in response.
|
||||
func EchoHTTP(l net.Listener) {
|
||||
|
|
|
|||
30
pool.go
30
pool.go
|
|
@ -11,21 +11,25 @@ import (
|
|||
"github.com/mmatczuk/tunnel/id"
|
||||
)
|
||||
|
||||
type onDisconnectListener func(identifier id.ID)
|
||||
|
||||
type connPair struct {
|
||||
conn net.Conn
|
||||
clientConn *http2.ClientConn
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
t *http2.Transport
|
||||
conns map[string]connPair // key is host:port
|
||||
mu sync.RWMutex
|
||||
t *http2.Transport
|
||||
conns map[string]connPair // key is host:port
|
||||
listener onDisconnectListener
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newConnPool(t *http2.Transport) *connPool {
|
||||
func newConnPool(t *http2.Transport, l onDisconnectListener) *connPool {
|
||||
return &connPool{
|
||||
t: t,
|
||||
conns: make(map[string]connPair),
|
||||
t: t,
|
||||
listener: l,
|
||||
conns: make(map[string]connPair),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -44,10 +48,13 @@ func (p *connPool) MarkDead(c *http2.ClientConn) {
|
|||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for identifier, cp := range p.conns {
|
||||
for addr, cp := range p.conns {
|
||||
if cp.clientConn == c {
|
||||
cp.conn.Close()
|
||||
delete(p.conns, identifier)
|
||||
delete(p.conns, addr)
|
||||
if p.listener != nil {
|
||||
p.listener(p.addrToIdentifier(addr))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -84,6 +91,9 @@ func (p *connPool) DeleteConn(identifier id.ID) {
|
|||
if cp, ok := p.conns[addr]; ok {
|
||||
cp.conn.Close()
|
||||
delete(p.conns, addr)
|
||||
if p.listener != nil {
|
||||
p.listener(identifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,3 +104,7 @@ func (p *connPool) URL(identifier id.ID) string {
|
|||
func (p *connPool) addr(identifier id.ID) string {
|
||||
return fmt.Sprint(identifier.String(), ":443")
|
||||
}
|
||||
|
||||
func (p *connPool) addrToIdentifier(addr string) id.ID {
|
||||
return id.NewFromString(addr[:len(addr)-4])
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ const (
|
|||
|
||||
// ControlMessage headers
|
||||
const (
|
||||
ErrorHeader = "Error"
|
||||
ForwardedHeader = "Forwarded"
|
||||
)
|
||||
|
||||
|
|
|
|||
181
registry.go
181
registry.go
|
|
@ -6,37 +6,56 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/mmatczuk/tunnel/id"
|
||||
"github.com/mmatczuk/tunnel/log"
|
||||
)
|
||||
|
||||
// RegistryItem holds information about hosts and listeners associated with a
|
||||
// client.
|
||||
type RegistryItem struct {
|
||||
Hosts []string
|
||||
Hosts []*HostAuth
|
||||
Listeners []net.Listener
|
||||
}
|
||||
|
||||
// HostAuth holds host and authentication info.
|
||||
type HostAuth struct {
|
||||
Host string
|
||||
Auth *Auth
|
||||
}
|
||||
|
||||
type hostInfo struct {
|
||||
identifier id.ID
|
||||
auth *Auth
|
||||
}
|
||||
|
||||
// registry manages client tunnels information.
|
||||
type registry struct {
|
||||
items map[id.ID]*RegistryItem
|
||||
hosts map[string]*hostInfo
|
||||
mu sync.RWMutex
|
||||
items map[id.ID]*RegistryItem
|
||||
hosts map[string]*hostInfo
|
||||
mu sync.RWMutex
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// newRegistry creates new registry.
|
||||
func newRegistry() *registry {
|
||||
func newRegistry(logger log.Logger) *registry {
|
||||
if logger == nil {
|
||||
logger = log.NewNopLogger()
|
||||
}
|
||||
|
||||
return ®istry{
|
||||
items: make(map[id.ID]*RegistryItem),
|
||||
hosts: make(map[string]*hostInfo, 0),
|
||||
items: make(map[id.ID]*RegistryItem),
|
||||
hosts: make(map[string]*hostInfo, 0),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds new client to registry, this method is idempotent.
|
||||
var voidRegistryItem = &RegistryItem{}
|
||||
|
||||
// Subscribe allows to connect client with a given identifier.
|
||||
func (r *registry) Subscribe(identifier id.ID) {
|
||||
r.logger.Log(
|
||||
"level", 1,
|
||||
"action", "subscribe",
|
||||
"identifier", identifier,
|
||||
)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
|
|
@ -44,13 +63,10 @@ func (r *registry) Subscribe(identifier id.ID) {
|
|||
return
|
||||
}
|
||||
|
||||
r.items[identifier] = &RegistryItem{
|
||||
Hosts: make([]string, 0),
|
||||
Listeners: make([]net.Listener, 0),
|
||||
}
|
||||
r.items[identifier] = voidRegistryItem
|
||||
}
|
||||
|
||||
// IsSubscribed returns true if client is subscribed to registry.
|
||||
// IsSubscribed returns true if client is subscribed.
|
||||
func (r *registry) IsSubscribed(identifier id.ID) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
|
@ -60,12 +76,10 @@ func (r *registry) IsSubscribed(identifier id.ID) bool {
|
|||
|
||||
// Subscriber returns client identifier assigned to given host.
|
||||
func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
|
||||
host := trimPort(hostPort)
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
h, ok := r.hosts[host]
|
||||
h, ok := r.hosts[trimPort(hostPort)]
|
||||
if !ok {
|
||||
return id.ID{}, nil, false
|
||||
}
|
||||
|
|
@ -73,8 +87,14 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
|
|||
return h.identifier, h.auth, ok
|
||||
}
|
||||
|
||||
// Unsubscribe removes client from registy and returns it's RegistryItem.
|
||||
// Unsubscribe removes client from registry and returns it's RegistryItem.
|
||||
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
|
||||
r.logger.Log(
|
||||
"level", 1,
|
||||
"action", "unsubscribe",
|
||||
"identifier", identifier,
|
||||
)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
|
|
@ -82,9 +102,14 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
|
|||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if i == voidRegistryItem {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, h := range i.Hosts {
|
||||
delete(r.hosts, h)
|
||||
if i.Hosts != nil {
|
||||
for _, h := range i.Hosts {
|
||||
delete(r.hosts, h.Host)
|
||||
}
|
||||
}
|
||||
|
||||
delete(r.items, identifier)
|
||||
|
|
@ -92,103 +117,71 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
|
|||
return i
|
||||
}
|
||||
|
||||
// AddHost assigns host to client unless the host is not already taken.
|
||||
func (r *registry) AddHost(hostPort string, auth *Auth, identifier id.ID) error {
|
||||
host := trimPort(hostPort)
|
||||
func (r *registry) set(i *RegistryItem, identifier id.ID) error {
|
||||
r.logger.Log(
|
||||
"level", 2,
|
||||
"action", "set registry item",
|
||||
"identifier", identifier,
|
||||
)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if auth != nil && auth.User == "" {
|
||||
return fmt.Errorf("Missing auth user")
|
||||
}
|
||||
|
||||
i, ok := r.items[identifier]
|
||||
j, ok := r.items[identifier]
|
||||
if !ok {
|
||||
return errClientNotSubscribed
|
||||
}
|
||||
|
||||
if _, ok := r.hosts[host]; ok {
|
||||
return fmt.Errorf("host %q is occupied", host)
|
||||
}
|
||||
r.hosts[host] = &hostInfo{
|
||||
identifier: identifier,
|
||||
auth: auth,
|
||||
if j != voidRegistryItem {
|
||||
return fmt.Errorf("attempt to overwrite registry item")
|
||||
}
|
||||
|
||||
i.Hosts = append(i.Hosts, host)
|
||||
if i.Hosts != nil {
|
||||
for _, h := range i.Hosts {
|
||||
if h.Auth != nil && h.Auth.User == "" {
|
||||
return fmt.Errorf("missing auth user")
|
||||
}
|
||||
if _, ok := r.hosts[trimPort(h.Host)]; ok {
|
||||
return fmt.Errorf("host %q is occupied", h.Host)
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range i.Hosts {
|
||||
r.hosts[trimPort(h.Host)] = &hostInfo{
|
||||
identifier: identifier,
|
||||
auth: h.Auth,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.items[identifier] = i
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHost unassigns host from client.
|
||||
func (r *registry) DeleteHost(hostPort string, identifier id.ID) {
|
||||
host := trimPort(hostPort)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if h, ok := r.hosts[host]; !ok || h.identifier != identifier {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.hosts, host)
|
||||
|
||||
i := r.items[identifier]
|
||||
for k, v := range i.Hosts {
|
||||
if v == host {
|
||||
i.Hosts = append(i.Hosts[:k], i.Hosts[k+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddListener adds client listener.
|
||||
func (r *registry) AddListener(l net.Listener, identifier id.ID) error {
|
||||
if l == nil {
|
||||
panic("Missing listener")
|
||||
}
|
||||
func (r *registry) clear(identifier id.ID) *RegistryItem {
|
||||
r.logger.Log(
|
||||
"level", 2,
|
||||
"action", "clear registry item",
|
||||
"identifier", identifier,
|
||||
)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
i, ok := r.items[identifier]
|
||||
if !ok {
|
||||
return errClientNotSubscribed
|
||||
if !ok || i == voidRegistryItem {
|
||||
return nil
|
||||
}
|
||||
|
||||
for k, v := range i.Listeners {
|
||||
if v == l {
|
||||
return fmt.Errorf("listener already added at %d", k)
|
||||
if i.Hosts != nil {
|
||||
for _, h := range i.Hosts {
|
||||
delete(r.hosts, trimPort(h.Host))
|
||||
}
|
||||
}
|
||||
|
||||
i.Listeners = append(i.Listeners, l)
|
||||
r.items[identifier] = voidRegistryItem
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteListener removes listener from client. Listener must be closed to stop
|
||||
// accepting go routine.
|
||||
func (r *registry) DeleteListener(l net.Listener, identifier id.ID) {
|
||||
if l == nil {
|
||||
panic("Missing listener")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
i, ok := r.items[identifier]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range i.Listeners {
|
||||
if v == l {
|
||||
i.Listeners = append(i.Listeners[:k], i.Listeners[k+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func trimPort(hostPort string) (host string) {
|
||||
|
|
|
|||
230
registry_test.go
230
registry_test.go
|
|
@ -1,230 +0,0 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mmatczuk/tunnel/id"
|
||||
)
|
||||
|
||||
var (
|
||||
a = id.NewFromString("A")
|
||||
b = id.NewFromString("B")
|
||||
)
|
||||
|
||||
func TestRegistry_IsSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
|
||||
if !r.IsSubscribed(a) {
|
||||
t.Fatal("Client should be subscribed")
|
||||
}
|
||||
if r.IsSubscribed(b) {
|
||||
t.Fatal("Client should not be subscribed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_UnsubscribeOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
|
||||
if r.Unsubscribe(a) == nil {
|
||||
t.Fatal("Unsubscribe should return RegistryItem")
|
||||
}
|
||||
if r.Unsubscribe(a) != nil {
|
||||
t.Fatal("Unsubscribe should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_UnsubscribeReturnsHosts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
r.AddHost("host1", nil, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if !reflect.DeepEqual(i.Hosts, []string{"host0", "host1"}) {
|
||||
t.Fatal("RegistryItem should contain hosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_UnsubscribeReturnsListeners(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l0 := &net.TCPListener{}
|
||||
l1 := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l0, a)
|
||||
r.AddListener(l1, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if !reflect.DeepEqual(i.Listeners, []net.Listener{l0, l1}) {
|
||||
t.Fatal("RegistryItem should contain hosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostOnlyToSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddHost("host0", nil, a); err != errClientNotSubscribed {
|
||||
t.Fatal("Adding host should be possible to subscribned clients only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", &Auth{User: "A", Password: "B"}, a)
|
||||
|
||||
_, auth, _ := r.Subscriber("host0")
|
||||
if !reflect.DeepEqual(auth, &Auth{User: "A", Password: "B"}) {
|
||||
t.Fatal("Expected auth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostTrimsPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
r.AddHost("host1:80", nil, a)
|
||||
|
||||
tests := []string{
|
||||
"host0",
|
||||
"host0:80",
|
||||
"host0:8080",
|
||||
"host1",
|
||||
"host1:80",
|
||||
"host1:8080",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
identifier, auth, ok := r.Subscriber(tt)
|
||||
if !ok {
|
||||
t.Fatal("Subscriber not found")
|
||||
}
|
||||
if auth != nil {
|
||||
t.Fatal("Unexpeted auth")
|
||||
}
|
||||
if identifier != a {
|
||||
t.Fatal("Unexpeted identifier")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
tests := []string{
|
||||
"host0",
|
||||
"host0:80",
|
||||
"host0:8080",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if err := r.AddHost(tt, nil, a); !strings.Contains(err.Error(), "occupied") {
|
||||
t.Log(tt)
|
||||
t.Errorf("Adding host %q should fail", tt)
|
||||
}
|
||||
}
|
||||
|
||||
r.Subscribe(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
if err := r.AddHost(tt, nil, b); !strings.Contains(err.Error(), "occupied") {
|
||||
t.Log(err)
|
||||
t.Errorf("Adding host %q should fail", tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_DeleteHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
r.DeleteHost("host0", a)
|
||||
|
||||
if _, _, ok := r.Subscriber("host0"); ok {
|
||||
t.Fatal("Should delete host for a")
|
||||
}
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if len(i.Hosts) != 0 {
|
||||
t.Fatal("Host was not deleted from item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_DeleteOnlyOwnedHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
r.DeleteHost("host0", b)
|
||||
|
||||
if _, _, ok := r.Subscriber("host0"); !ok {
|
||||
t.Fatal("Should not delete host for b")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddListenerOnlyToSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddListener(&net.TCPListener{}, a); err != errClientNotSubscribed {
|
||||
t.Fatal("Adding listener should be possible to subscribned clients only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddListenerOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l, a)
|
||||
|
||||
if err := r.AddListener(l, a); err == nil {
|
||||
t.Fatal("Adding listenr should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_DeleteListener(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l, a)
|
||||
|
||||
r.DeleteListener(l, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if len(i.Listeners) != 0 {
|
||||
t.Fatal("Host was not deleted from item")
|
||||
}
|
||||
}
|
||||
269
server.go
269
server.go
|
|
@ -1,6 +1,7 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
|
@ -45,26 +46,28 @@ type Server struct {
|
|||
func NewServer(config *ServerConfig) (*Server, error) {
|
||||
listener, err := listener(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tls listener failed :%s", err)
|
||||
return nil, fmt.Errorf("tls listener failedAuthorization: %s", err)
|
||||
}
|
||||
|
||||
t := &http2.Transport{}
|
||||
pool := newConnPool(t)
|
||||
t.ConnPool = pool
|
||||
|
||||
logger := config.Logger
|
||||
if logger == nil {
|
||||
logger = log.NewNopLogger()
|
||||
}
|
||||
|
||||
return &Server{
|
||||
registry: newRegistry(),
|
||||
config: config,
|
||||
listener: listener,
|
||||
connPool: pool,
|
||||
httpClient: &http.Client{Transport: t},
|
||||
logger: logger,
|
||||
}, nil
|
||||
s := &Server{
|
||||
registry: newRegistry(logger),
|
||||
config: config,
|
||||
listener: listener,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
t := &http2.Transport{}
|
||||
pool := newConnPool(t, s.disconnected)
|
||||
t.ConnPool = pool
|
||||
s.connPool = pool
|
||||
s.httpClient = &http.Client{Transport: t}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func listener(config *ServerConfig) (net.Listener, error) {
|
||||
|
|
@ -82,18 +85,40 @@ func listener(config *ServerConfig) (net.Listener, error) {
|
|||
return tls.Listen("tcp", config.Addr, config.TLSConfig)
|
||||
}
|
||||
|
||||
// Start starts accepting connections form clients and allowed clients listeners.
|
||||
// For accepting http traffic one must run server as a handler to http server.
|
||||
// disconnected clears resources used by client, it's invoked by connection pool
|
||||
// when client goes away.
|
||||
func (s *Server) disconnected(identifier id.ID) {
|
||||
s.logger.Log(
|
||||
"level", 1,
|
||||
"action", "disconnected",
|
||||
"identifier", identifier,
|
||||
)
|
||||
|
||||
i := s.registry.clear(identifier)
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
for _, l := range i.Listeners {
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "close listener",
|
||||
"identifier", identifier,
|
||||
"addr", l.Addr(),
|
||||
)
|
||||
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts accepting connections form clients. For accepting http traffic
|
||||
// from end users server must be run as handler on http server.
|
||||
func (s *Server) Start() {
|
||||
s.logger.Log(
|
||||
"level", 1,
|
||||
"action", "start",
|
||||
"addr", s.listener.Addr(),
|
||||
)
|
||||
go s.listenControl()
|
||||
}
|
||||
|
||||
func (s *Server) listenControl() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
|
|
@ -187,6 +212,11 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
goto reject
|
||||
}
|
||||
|
||||
{
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
req.WithContext(ctx)
|
||||
}
|
||||
resp, err = s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
|
|
@ -196,46 +226,53 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
)
|
||||
goto reject
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err = fmt.Errorf("Status %s", resp.Status)
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", fmt.Errorf("Status %s", resp.Status),
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
if resp.ContentLength > 0 {
|
||||
var done chan struct{}
|
||||
go func() {
|
||||
err = json.NewDecoder(&io.LimitedReader{
|
||||
R: resp.Body,
|
||||
N: 126976,
|
||||
}).Decode(&tunnels)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Minute):
|
||||
err = fmt.Errorf("timeout")
|
||||
}
|
||||
if resp.ContentLength == 0 {
|
||||
err = fmt.Errorf("Tunnels Content-Legth: 0")
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
if err = s.AddTunnels(tunnels, identifier); err != nil {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
|
||||
if len(tunnels) == 0 {
|
||||
err = fmt.Errorf("No tunnels")
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
|
||||
if err = s.addTunnels(tunnels, identifier); err != nil {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "handshake failed",
|
||||
"err", err,
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
|
||||
logger.Log(
|
||||
|
|
@ -246,60 +283,89 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
return
|
||||
|
||||
reject:
|
||||
s.logger.Log(
|
||||
logger.Log(
|
||||
"level", 1,
|
||||
"action", "rejected",
|
||||
"addr", conn.RemoteAddr(),
|
||||
)
|
||||
|
||||
s.notifyError(err, identifier)
|
||||
s.connPool.DeleteConn(identifier)
|
||||
}
|
||||
|
||||
// AddTunnels invokes AddHost or AddListener based on data from proto.Tunnel. If
|
||||
// a tunnel cannot be added whole batch is reverted.
|
||||
func (s *Server) AddTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
|
||||
var (
|
||||
hosts []string
|
||||
listeners []net.Listener
|
||||
err error
|
||||
)
|
||||
// notifyError tries to send error to client.
|
||||
func (s *Server) notifyError(serverError error, identifier id.ID) {
|
||||
if serverError == nil {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil)
|
||||
if err != nil {
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "client error notification failed",
|
||||
"identifier", identifier,
|
||||
"err", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set(proto.ErrorHeader, serverError.Error())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
req.WithContext(ctx)
|
||||
|
||||
s.httpClient.Do(req)
|
||||
}
|
||||
|
||||
// addTunnels invokes addHost or addListener based on data from proto.Tunnel. If
|
||||
// a tunnel cannot be added whole batch is reverted.
|
||||
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
|
||||
i := &RegistryItem{
|
||||
Hosts: []*HostAuth{},
|
||||
Listeners: []net.Listener{},
|
||||
}
|
||||
|
||||
var err error
|
||||
for name, t := range tunnels {
|
||||
switch t.Protocol {
|
||||
case proto.HTTP:
|
||||
err = s.AddHost(t.Host, NewAuth(t.Auth), identifier)
|
||||
if err != nil {
|
||||
goto rollback
|
||||
}
|
||||
hosts = append(hosts, t.Host)
|
||||
i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)})
|
||||
case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX:
|
||||
var l net.Listener
|
||||
l, err = net.Listen(t.Protocol, t.Addr)
|
||||
if err != nil {
|
||||
goto rollback
|
||||
}
|
||||
listeners = append(listeners, l)
|
||||
|
||||
err := s.AddListener(l, identifier)
|
||||
if err != nil {
|
||||
goto rollback
|
||||
}
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "open listener",
|
||||
"identifier", identifier,
|
||||
"addr", l.Addr(),
|
||||
)
|
||||
|
||||
i.Listeners = append(i.Listeners, l)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol)
|
||||
goto rollback
|
||||
}
|
||||
}
|
||||
|
||||
err = s.set(i, identifier)
|
||||
if err != nil {
|
||||
goto rollback
|
||||
}
|
||||
|
||||
for _, l := range i.Listeners {
|
||||
s.listen(l, identifier)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
rollback:
|
||||
for _, h := range hosts {
|
||||
s.DeleteHost(h, identifier)
|
||||
}
|
||||
|
||||
for _, l := range listeners {
|
||||
for _, l := range i.Listeners {
|
||||
l.Close()
|
||||
s.DeleteListener(l, identifier)
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
@ -312,17 +378,6 @@ func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
|
|||
return s.registry.Unsubscribe(identifier)
|
||||
}
|
||||
|
||||
// AddListener adds listener to client.
|
||||
func (s *Server) AddListener(l net.Listener, identifier id.ID) error {
|
||||
if err := s.registry.AddListener(l, identifier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go s.listen(l, identifier)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) listen(l net.Listener, identifier id.ID) {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
|
|
@ -330,6 +385,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
|
|||
s.logger.Log(
|
||||
"level", 2,
|
||||
"msg", "accept connection failed",
|
||||
"identifier", identifier,
|
||||
"err", err,
|
||||
)
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
|
|
@ -352,6 +408,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
|
|||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "proxy",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
)
|
||||
|
||||
|
|
@ -364,6 +421,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
|
|||
s.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "proxy error",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"err", err,
|
||||
)
|
||||
|
|
@ -371,17 +429,22 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
|
|||
return
|
||||
}
|
||||
|
||||
go transfer(pw, conn, log.NewContext(s.logger).With(
|
||||
"dir", "user to client",
|
||||
"dst", identifier,
|
||||
"src", conn.RemoteAddr(),
|
||||
))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
transfer(pw, conn, log.NewContext(s.logger).With(
|
||||
"dir", "user to client",
|
||||
"dst", identifier,
|
||||
"src", conn.RemoteAddr(),
|
||||
))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
s.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "proxy error",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"err", err,
|
||||
)
|
||||
|
|
@ -394,17 +457,27 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
|
|||
"dst", conn.RemoteAddr(),
|
||||
"src", identifier,
|
||||
))
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// ServeHTTP proxies http connection to the client.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := s.RoundTrip(r)
|
||||
|
||||
if err == errUnauthorised {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.logger.Log(
|
||||
"level", 0,
|
||||
"action", "round trip failed",
|
||||
"addr", r.RemoteAddr,
|
||||
"url", r.URL,
|
||||
"err", err,
|
||||
)
|
||||
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
|
@ -438,6 +511,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||
if auth.User != user || auth.Password != password {
|
||||
return nil, errUnauthorised
|
||||
}
|
||||
r.Header.Del("Authorization")
|
||||
}
|
||||
|
||||
return s.proxyHTTP(identifier, r, msg)
|
||||
|
|
@ -447,6 +521,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
|
|||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "proxy",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
)
|
||||
|
||||
|
|
@ -466,6 +541,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
|
|||
s.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "proxy error",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"err", err,
|
||||
)
|
||||
|
|
@ -474,6 +550,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
|
|||
s.logger.Log(
|
||||
"level", 3,
|
||||
"action", "transferred",
|
||||
"identifier", identifier,
|
||||
"bytes", cw.count,
|
||||
"dir", "user to client",
|
||||
"dst", r.Host,
|
||||
|
|
@ -490,6 +567,14 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
|
|||
return nil, fmt.Errorf("proxy error: %s", err)
|
||||
}
|
||||
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "proxy done",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"status code", resp.StatusCode,
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
10
tcpproxy.go
10
tcpproxy.go
|
|
@ -77,7 +77,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
|
|||
return
|
||||
}
|
||||
|
||||
local, err := net.DialTimeout("tcp", target, DefaultDialTimeout)
|
||||
local, err := net.DialTimeout("tcp", target, DefaultTimeout)
|
||||
if err != nil {
|
||||
p.logger.Log(
|
||||
"level", 0,
|
||||
|
|
@ -100,6 +100,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage
|
|||
"dst", target,
|
||||
"src", msg.ForwardedBy,
|
||||
))
|
||||
<-done
|
||||
}
|
||||
|
||||
func (p *TCPProxy) localAddrFor(hostPort string) string {
|
||||
|
|
@ -107,7 +108,7 @@ func (p *TCPProxy) localAddrFor(hostPort string) string {
|
|||
return p.localAddr
|
||||
}
|
||||
|
||||
// try host and port
|
||||
// try hostPort
|
||||
if addr := p.localAddrMap[hostPort]; addr != "" {
|
||||
return addr
|
||||
}
|
||||
|
|
@ -118,6 +119,11 @@ func (p *TCPProxy) localAddrFor(hostPort string) string {
|
|||
return addr
|
||||
}
|
||||
|
||||
// try 0.0.0.0:port
|
||||
if addr := p.localAddrMap[fmt.Sprintf("0.0.0.0:%s", port)]; addr != "" {
|
||||
return addr
|
||||
}
|
||||
|
||||
// try host
|
||||
if addr := p.localAddrMap[host]; addr != "" {
|
||||
return addr
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue