mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
server: basic auth
This commit is contained in:
parent
96c46be0f0
commit
c9768f8c1c
8 changed files with 292 additions and 81 deletions
1
TODO.md
1
TODO.md
|
|
@ -1,7 +1,6 @@
|
|||
Release 1.0
|
||||
|
||||
1. cli: cli and file configuration based on ngrok2 https://ngrok.com/docs#config
|
||||
1. security: basic auth on server
|
||||
1. docs: README update
|
||||
|
||||
Backlog
|
||||
|
|
|
|||
26
auth.go
Normal file
26
auth.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package tunnel
|
||||
|
||||
import "strings"
|
||||
|
||||
// Auth holds user and password.
|
||||
type Auth struct {
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
// NewAuth creates new auth from string representation.
|
||||
func NewAuth(auth string) *Auth {
|
||||
if auth == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := strings.SplitN(auth, ":", 2)
|
||||
a := &Auth{
|
||||
User: s[0],
|
||||
}
|
||||
if len(s) > 1 {
|
||||
a.Password = s[1]
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
24
auth_test.go
Normal file
24
auth_test.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
actual string
|
||||
expected *Auth
|
||||
}{
|
||||
{"", nil},
|
||||
{"user", &Auth{User: "user"}},
|
||||
{"user:password", &Auth{User: "user", Password: "password"}},
|
||||
{"user:pass:word", &Auth{User: "user", Password: "pass:word"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if !reflect.DeepEqual(NewAuth(tt.actual), tt.expected) {
|
||||
t.Errorf("Invalid auth for %s", tt.actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -6,4 +6,6 @@ var (
|
|||
errClientNotSubscribed = errors.New("client not subscribed")
|
||||
errClientNotConnected = errors.New("client not connected")
|
||||
errClientAlreadyConnected = errors.New("client already connected")
|
||||
|
||||
errUnauthorised = errors.New("unauthorised")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,13 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
s.Subscribe(identifier)
|
||||
if err := s.AddHost("localhost", identifier); err != nil {
|
||||
|
||||
auth := &tunnel.Auth{
|
||||
User: "user",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
if err := s.AddHost("localhost", auth, identifier); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := s.AddListener(serverTCPListener, identifier); err != nil {
|
||||
|
|
@ -165,6 +171,8 @@ func testHTTP(t *testing.T, payload []byte, repeat uint) {
|
|||
if err != nil {
|
||||
panic("Failed to create request")
|
||||
}
|
||||
r.SetBasicAuth("user", "password")
|
||||
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("HTTP error %s", err))
|
||||
|
|
|
|||
50
registry.go
50
registry.go
|
|
@ -15,18 +15,23 @@ type RegistryItem struct {
|
|||
Listeners []net.Listener
|
||||
}
|
||||
|
||||
type hostInfo struct {
|
||||
identifier id.ID
|
||||
auth *Auth
|
||||
}
|
||||
|
||||
// registry manages client tunnels information.
|
||||
type registry struct {
|
||||
items map[id.ID]*RegistryItem
|
||||
hostIdx map[string]id.ID
|
||||
mu sync.RWMutex
|
||||
items map[id.ID]*RegistryItem
|
||||
hosts map[string]*hostInfo
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// newRegistry creates new registry.
|
||||
func newRegistry() *registry {
|
||||
return ®istry{
|
||||
items: make(map[id.ID]*RegistryItem),
|
||||
hostIdx: make(map[string]id.ID, 0),
|
||||
items: make(map[id.ID]*RegistryItem),
|
||||
hosts: make(map[string]*hostInfo, 0),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -54,14 +59,18 @@ func (r *registry) IsSubscribed(identifier id.ID) bool {
|
|||
}
|
||||
|
||||
// Subscriber returns client identifier assigned to given host.
|
||||
func (r *registry) Subscriber(hostPort string) (id.ID, bool) {
|
||||
func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
|
||||
host := trimPort(hostPort)
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
identifier, ok := r.hostIdx[host]
|
||||
return identifier, ok
|
||||
h, ok := r.hosts[host]
|
||||
if !ok {
|
||||
return id.ID{}, nil, false
|
||||
}
|
||||
|
||||
return h.identifier, h.auth, ok
|
||||
}
|
||||
|
||||
// Unsubscribe removes client from registy and returns it's RegistryItem.
|
||||
|
|
@ -75,7 +84,7 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
|
|||
}
|
||||
|
||||
for _, h := range i.Hosts {
|
||||
delete(r.hostIdx, h)
|
||||
delete(r.hosts, h)
|
||||
}
|
||||
|
||||
delete(r.items, identifier)
|
||||
|
|
@ -84,21 +93,28 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
|
|||
}
|
||||
|
||||
// AddHost assigns host to client unless the host is not already taken.
|
||||
func (r *registry) AddHost(hostPort string, identifier id.ID) error {
|
||||
func (r *registry) AddHost(hostPort string, auth *Auth, identifier id.ID) error {
|
||||
host := trimPort(hostPort)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if auth != nil && auth.User == "" {
|
||||
return fmt.Errorf("Missing auth user")
|
||||
}
|
||||
|
||||
i, ok := r.items[identifier]
|
||||
if !ok {
|
||||
return errClientNotSubscribed
|
||||
}
|
||||
|
||||
if _, ok := r.hostIdx[host]; ok {
|
||||
if _, ok := r.hosts[host]; ok {
|
||||
return fmt.Errorf("host %q is occupied", host)
|
||||
}
|
||||
r.hostIdx[host] = identifier
|
||||
r.hosts[host] = &hostInfo{
|
||||
identifier: identifier,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
i.Hosts = append(i.Hosts, host)
|
||||
|
||||
|
|
@ -112,11 +128,11 @@ func (r *registry) DeleteHost(hostPort string, identifier id.ID) {
|
|||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if hostIdentifier, ok := r.hostIdx[host]; !ok || hostIdentifier != identifier {
|
||||
if h, ok := r.hosts[host]; !ok || h.identifier != identifier {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.hostIdx, host)
|
||||
delete(r.hosts, host)
|
||||
|
||||
i := r.items[identifier]
|
||||
for k, v := range i.Hosts {
|
||||
|
|
@ -141,6 +157,12 @@ func (r *registry) AddListener(l net.Listener, identifier id.ID) error {
|
|||
return errClientNotSubscribed
|
||||
}
|
||||
|
||||
for k, v := range i.Listeners {
|
||||
if v == l {
|
||||
return fmt.Errorf("listener already added at %d", k)
|
||||
}
|
||||
}
|
||||
|
||||
i.Listeners = append(i.Listeners, l)
|
||||
|
||||
return nil
|
||||
|
|
|
|||
243
registry_test.go
243
registry_test.go
|
|
@ -2,60 +2,157 @@ package tunnel
|
|||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mmatczuk/tunnel/id"
|
||||
)
|
||||
|
||||
func TestRegistry_Subscribe(t *testing.T) {
|
||||
var (
|
||||
a = id.NewFromString("A")
|
||||
b = id.NewFromString("B")
|
||||
)
|
||||
|
||||
func TestRegistry_IsSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(id.NewFromString("A"))
|
||||
r.Subscribe(a)
|
||||
|
||||
if ok := r.IsSubscribed(id.NewFromString("A")); !ok {
|
||||
if !r.IsSubscribed(a) {
|
||||
t.Fatal("Client should be subscribed")
|
||||
}
|
||||
|
||||
if i := r.Unsubscribe(id.NewFromString("A")); i == nil {
|
||||
t.Fatal("Unsubscribe should return RegistryItem")
|
||||
}
|
||||
if i := r.Unsubscribe(id.NewFromString("A")); i != nil {
|
||||
t.Fatal("Unsubscribe for not existing client should return null")
|
||||
if r.IsSubscribed(b) {
|
||||
t.Fatal("Client should not be subscribed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHost(t *testing.T) {
|
||||
func TestRegistry_UnsubscribeOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddHost("foobar", id.NewFromString("A")); err != errClientNotSubscribed {
|
||||
t.Fatal("AddHost to not subscribed client should fail")
|
||||
r.Subscribe(a)
|
||||
|
||||
if r.Unsubscribe(a) == nil {
|
||||
t.Fatal("Unsubscribe should return RegistryItem")
|
||||
}
|
||||
if r.Unsubscribe(a) != nil {
|
||||
t.Fatal("Unsubscribe should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_UnsubscribeReturnsHosts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
r.AddHost("host1", nil, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if !reflect.DeepEqual(i.Hosts, []string{"host0", "host1"}) {
|
||||
t.Fatal("RegistryItem should contain hosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_UnsubscribeReturnsListeners(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l0 := &net.TCPListener{}
|
||||
l1 := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l0, a)
|
||||
r.AddListener(l1, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if !reflect.DeepEqual(i.Listeners, []net.Listener{l0, l1}) {
|
||||
t.Fatal("RegistryItem should contain hosts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostOnlyToSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddHost("host0", nil, a); err != errClientNotSubscribed {
|
||||
t.Fatal("Adding host should be possible to subscribned clients only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", &Auth{User: "A", Password: "B"}, a)
|
||||
|
||||
_, auth, _ := r.Subscriber("host0")
|
||||
if !reflect.DeepEqual(auth, &Auth{User: "A", Password: "B"}) {
|
||||
t.Fatal("Expected auth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddHostTrimsPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
r.AddHost("host1:80", nil, a)
|
||||
|
||||
tests := []string{
|
||||
"host0",
|
||||
"host0:80",
|
||||
"host0:8080",
|
||||
"host1",
|
||||
"host1:80",
|
||||
"host1:8080",
|
||||
}
|
||||
|
||||
r.Subscribe(id.NewFromString("A"))
|
||||
for _, tt := range tests {
|
||||
identifier, auth, ok := r.Subscriber(tt)
|
||||
if !ok {
|
||||
t.Fatal("Subscriber not found")
|
||||
}
|
||||
if auth != nil {
|
||||
t.Fatal("Unexpeted auth")
|
||||
}
|
||||
if identifier != a {
|
||||
t.Fatal("Unexpeted identifier")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
|
||||
t.Fatal("AddHost should succeed")
|
||||
func TestRegistry_AddHostOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
tests := []string{
|
||||
"host0",
|
||||
"host0:80",
|
||||
"host0:8080",
|
||||
}
|
||||
|
||||
r.Subscribe(id.NewFromString("B"))
|
||||
|
||||
if err := r.AddHost("foobar", id.NewFromString("B")); err == nil {
|
||||
t.Fatal("AddHost for duplicate host should fail")
|
||||
for _, tt := range tests {
|
||||
if err := r.AddHost(tt, nil, a); !strings.Contains(err.Error(), "occupied") {
|
||||
t.Log(tt)
|
||||
t.Errorf("Adding host %q should fail", tt)
|
||||
}
|
||||
}
|
||||
|
||||
if identifier, ok := r.Subscriber("foobar"); !ok || identifier != id.NewFromString("A") {
|
||||
t.Fatal("Wrong subscriber")
|
||||
}
|
||||
if identifier, ok := r.Subscriber("foobar:8080"); !ok || identifier != id.NewFromString("A") {
|
||||
t.Fatal("Wrong subscriber")
|
||||
}
|
||||
r.Subscribe(b)
|
||||
|
||||
r.Unsubscribe(id.NewFromString("A"))
|
||||
|
||||
if err := r.AddHost("foobar", id.NewFromString("B")); err != nil {
|
||||
t.Fatal("Unsubsribe failed to remove host")
|
||||
for _, tt := range tests {
|
||||
if err := r.AddHost(tt, nil, b); !strings.Contains(err.Error(), "occupied") {
|
||||
t.Log(err)
|
||||
t.Errorf("Adding host %q should fail", tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -63,49 +160,71 @@ func TestRegistry_DeleteHost(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(id.NewFromString("A"))
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
|
||||
t.Fatal("AddHost should succeed")
|
||||
r.DeleteHost("host0", a)
|
||||
|
||||
if _, _, ok := r.Subscriber("host0"); ok {
|
||||
t.Fatal("Should delete host for a")
|
||||
}
|
||||
|
||||
if identifier, ok := r.Subscriber("foobar"); !ok || identifier != id.NewFromString("A") {
|
||||
t.Fatal("Wrong subscriber")
|
||||
}
|
||||
|
||||
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err == nil {
|
||||
t.Fatal("AddHost for duplicate host should fail")
|
||||
}
|
||||
|
||||
r.DeleteHost("foobar", id.NewFromString("A"))
|
||||
|
||||
if _, ok := r.Subscriber("foobar"); ok {
|
||||
t.Fatal("DeleteHost failed to delete host")
|
||||
}
|
||||
|
||||
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
|
||||
t.Fatal("AddHost should succeed")
|
||||
}
|
||||
|
||||
r.Subscribe(id.NewFromString("B"))
|
||||
r.DeleteHost("foobar", id.NewFromString("B"))
|
||||
|
||||
if _, ok := r.Subscriber("foobar"); !ok {
|
||||
t.Fatal("DeleteHost forgein host should have no effect")
|
||||
i := r.Unsubscribe(a)
|
||||
if len(i.Hosts) != 0 {
|
||||
t.Fatal("Host was not deleted from item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddListener(t *testing.T) {
|
||||
func TestRegistry_DeleteOnlyOwnedHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddListener(&net.TCPListener{}, id.NewFromString("A")); err != errClientNotSubscribed {
|
||||
t.Fatal("AddListener to not subscribed client should fail")
|
||||
}
|
||||
r.Subscribe(a)
|
||||
r.AddHost("host0", nil, a)
|
||||
|
||||
r.Subscribe(id.NewFromString("A"))
|
||||
r.DeleteHost("host0", b)
|
||||
|
||||
if err := r.AddListener(&net.TCPListener{}, id.NewFromString("A")); err != nil {
|
||||
t.Fatal("AddListener should succeed")
|
||||
if _, _, ok := r.Subscriber("host0"); !ok {
|
||||
t.Fatal("Should not delete host for b")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddListenerOnlyToSubscribed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := newRegistry()
|
||||
if err := r.AddListener(&net.TCPListener{}, a); err != errClientNotSubscribed {
|
||||
t.Fatal("Adding listener should be possible to subscribned clients only")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AddListenerOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l, a)
|
||||
|
||||
if err := r.AddListener(l, a); err == nil {
|
||||
t.Fatal("Adding listenr should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_DeleteListener(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := &net.TCPListener{}
|
||||
|
||||
r := newRegistry()
|
||||
r.Subscribe(a)
|
||||
r.AddListener(l, a)
|
||||
|
||||
r.DeleteListener(l, a)
|
||||
|
||||
i := r.Unsubscribe(a)
|
||||
if len(i.Listeners) != 0 {
|
||||
t.Fatal("Host was not deleted from item")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
17
server.go
17
server.go
|
|
@ -267,7 +267,7 @@ func (s *Server) AddTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID)
|
|||
for name, t := range tunnels {
|
||||
switch t.Protocol {
|
||||
case proto.HTTP:
|
||||
err = s.AddHost(t.Host, identifier)
|
||||
err = s.AddHost(t.Host, NewAuth(t.Auth), identifier)
|
||||
if err != nil {
|
||||
goto rollback
|
||||
}
|
||||
|
|
@ -399,6 +399,11 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
|
|||
// ServeHTTP proxies http connection to the client.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := s.RoundTrip(r)
|
||||
|
||||
if err == errUnauthorised {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
|
|
@ -424,9 +429,15 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||
ForwardedBy: r.Host,
|
||||
}
|
||||
|
||||
identifier, ok := s.Subscriber(r.Host)
|
||||
identifier, auth, ok := s.Subscriber(r.Host)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("proxy request error: %s", errClientNotSubscribed)
|
||||
return nil, errClientNotSubscribed
|
||||
}
|
||||
if auth != nil {
|
||||
user, password, _ := r.BasicAuth()
|
||||
if auth.User != user || auth.Password != password {
|
||||
return nil, errUnauthorised
|
||||
}
|
||||
}
|
||||
|
||||
return s.proxyHTTP(identifier, r, msg)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue