server: basic auth

This commit is contained in:
mmatczuk 2017-02-08 14:15:23 +01:00
parent 96c46be0f0
commit c9768f8c1c
8 changed files with 292 additions and 81 deletions

View file

@ -1,7 +1,6 @@
Release 1.0
1. cli: cli and file configuration based on ngrok2 https://ngrok.com/docs#config
1. security: basic auth on server
1. docs: README update
Backlog

26
auth.go Normal file
View file

@ -0,0 +1,26 @@
package tunnel
import "strings"
// Auth holds user and password.
type Auth struct {
User string
Password string
}
// NewAuth creates new auth from string representation.
func NewAuth(auth string) *Auth {
if auth == "" {
return nil
}
s := strings.SplitN(auth, ":", 2)
a := &Auth{
User: s[0],
}
if len(s) > 1 {
a.Password = s[1]
}
return a
}

24
auth_test.go Normal file
View file

@ -0,0 +1,24 @@
package tunnel
import (
"reflect"
"testing"
)
func TestNewAuth(t *testing.T) {
tests := []struct {
actual string
expected *Auth
}{
{"", nil},
{"user", &Auth{User: "user"}},
{"user:password", &Auth{User: "user", Password: "password"}},
{"user:pass:word", &Auth{User: "user", Password: "pass:word"}},
}
for _, tt := range tests {
if !reflect.DeepEqual(NewAuth(tt.actual), tt.expected) {
t.Errorf("Invalid auth for %s", tt.actual)
}
}
}

View file

@ -6,4 +6,6 @@ var (
errClientNotSubscribed = errors.New("client not subscribed")
errClientNotConnected = errors.New("client not connected")
errClientAlreadyConnected = errors.New("client already connected")
errUnauthorised = errors.New("unauthorised")
)

View file

@ -56,7 +56,13 @@ func TestMain(m *testing.M) {
}
s.Subscribe(identifier)
if err := s.AddHost("localhost", identifier); err != nil {
auth := &tunnel.Auth{
User: "user",
Password: "password",
}
if err := s.AddHost("localhost", auth, identifier); err != nil {
panic(err)
}
if err := s.AddListener(serverTCPListener, identifier); err != nil {
@ -165,6 +171,8 @@ func testHTTP(t *testing.T, payload []byte, repeat uint) {
if err != nil {
panic("Failed to create request")
}
r.SetBasicAuth("user", "password")
resp, err := http.DefaultClient.Do(r)
if err != nil {
panic(fmt.Sprintf("HTTP error %s", err))

View file

@ -15,18 +15,23 @@ type RegistryItem struct {
Listeners []net.Listener
}
type hostInfo struct {
identifier id.ID
auth *Auth
}
// registry manages client tunnels information.
type registry struct {
items map[id.ID]*RegistryItem
hostIdx map[string]id.ID
mu sync.RWMutex
items map[id.ID]*RegistryItem
hosts map[string]*hostInfo
mu sync.RWMutex
}
// newRegistry creates new registry.
func newRegistry() *registry {
return &registry{
items: make(map[id.ID]*RegistryItem),
hostIdx: make(map[string]id.ID, 0),
items: make(map[id.ID]*RegistryItem),
hosts: make(map[string]*hostInfo, 0),
}
}
@ -54,14 +59,18 @@ func (r *registry) IsSubscribed(identifier id.ID) bool {
}
// Subscriber returns client identifier assigned to given host.
func (r *registry) Subscriber(hostPort string) (id.ID, bool) {
func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
host := trimPort(hostPort)
r.mu.RLock()
defer r.mu.RUnlock()
identifier, ok := r.hostIdx[host]
return identifier, ok
h, ok := r.hosts[host]
if !ok {
return id.ID{}, nil, false
}
return h.identifier, h.auth, ok
}
// Unsubscribe removes client from registy and returns it's RegistryItem.
@ -75,7 +84,7 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
}
for _, h := range i.Hosts {
delete(r.hostIdx, h)
delete(r.hosts, h)
}
delete(r.items, identifier)
@ -84,21 +93,28 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
}
// AddHost assigns host to client unless the host is not already taken.
func (r *registry) AddHost(hostPort string, identifier id.ID) error {
func (r *registry) AddHost(hostPort string, auth *Auth, identifier id.ID) error {
host := trimPort(hostPort)
r.mu.Lock()
defer r.mu.Unlock()
if auth != nil && auth.User == "" {
return fmt.Errorf("Missing auth user")
}
i, ok := r.items[identifier]
if !ok {
return errClientNotSubscribed
}
if _, ok := r.hostIdx[host]; ok {
if _, ok := r.hosts[host]; ok {
return fmt.Errorf("host %q is occupied", host)
}
r.hostIdx[host] = identifier
r.hosts[host] = &hostInfo{
identifier: identifier,
auth: auth,
}
i.Hosts = append(i.Hosts, host)
@ -112,11 +128,11 @@ func (r *registry) DeleteHost(hostPort string, identifier id.ID) {
r.mu.Lock()
defer r.mu.Unlock()
if hostIdentifier, ok := r.hostIdx[host]; !ok || hostIdentifier != identifier {
if h, ok := r.hosts[host]; !ok || h.identifier != identifier {
return
}
delete(r.hostIdx, host)
delete(r.hosts, host)
i := r.items[identifier]
for k, v := range i.Hosts {
@ -141,6 +157,12 @@ func (r *registry) AddListener(l net.Listener, identifier id.ID) error {
return errClientNotSubscribed
}
for k, v := range i.Listeners {
if v == l {
return fmt.Errorf("listener already added at %d", k)
}
}
i.Listeners = append(i.Listeners, l)
return nil

View file

@ -2,60 +2,157 @@ package tunnel
import (
"net"
"reflect"
"strings"
"testing"
"github.com/mmatczuk/tunnel/id"
)
func TestRegistry_Subscribe(t *testing.T) {
var (
a = id.NewFromString("A")
b = id.NewFromString("B")
)
func TestRegistry_IsSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(id.NewFromString("A"))
r.Subscribe(a)
if ok := r.IsSubscribed(id.NewFromString("A")); !ok {
if !r.IsSubscribed(a) {
t.Fatal("Client should be subscribed")
}
if i := r.Unsubscribe(id.NewFromString("A")); i == nil {
t.Fatal("Unsubscribe should return RegistryItem")
}
if i := r.Unsubscribe(id.NewFromString("A")); i != nil {
t.Fatal("Unsubscribe for not existing client should return null")
if r.IsSubscribed(b) {
t.Fatal("Client should not be subscribed")
}
}
func TestRegistry_AddHost(t *testing.T) {
func TestRegistry_UnsubscribeOnce(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddHost("foobar", id.NewFromString("A")); err != errClientNotSubscribed {
t.Fatal("AddHost to not subscribed client should fail")
r.Subscribe(a)
if r.Unsubscribe(a) == nil {
t.Fatal("Unsubscribe should return RegistryItem")
}
if r.Unsubscribe(a) != nil {
t.Fatal("Unsubscribe should return nil")
}
}
func TestRegistry_UnsubscribeReturnsHosts(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.AddHost("host1", nil, a)
i := r.Unsubscribe(a)
if !reflect.DeepEqual(i.Hosts, []string{"host0", "host1"}) {
t.Fatal("RegistryItem should contain hosts")
}
}
func TestRegistry_UnsubscribeReturnsListeners(t *testing.T) {
t.Parallel()
l0 := &net.TCPListener{}
l1 := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l0, a)
r.AddListener(l1, a)
i := r.Unsubscribe(a)
if !reflect.DeepEqual(i.Listeners, []net.Listener{l0, l1}) {
t.Fatal("RegistryItem should contain hosts")
}
}
func TestRegistry_AddHostOnlyToSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddHost("host0", nil, a); err != errClientNotSubscribed {
t.Fatal("Adding host should be possible to subscribned clients only")
}
}
func TestRegistry_AddHostAuth(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", &Auth{User: "A", Password: "B"}, a)
_, auth, _ := r.Subscriber("host0")
if !reflect.DeepEqual(auth, &Auth{User: "A", Password: "B"}) {
t.Fatal("Expected auth")
}
}
func TestRegistry_AddHostTrimsPort(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.AddHost("host1:80", nil, a)
tests := []string{
"host0",
"host0:80",
"host0:8080",
"host1",
"host1:80",
"host1:8080",
}
r.Subscribe(id.NewFromString("A"))
for _, tt := range tests {
identifier, auth, ok := r.Subscriber(tt)
if !ok {
t.Fatal("Subscriber not found")
}
if auth != nil {
t.Fatal("Unexpeted auth")
}
if identifier != a {
t.Fatal("Unexpeted identifier")
}
}
}
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
t.Fatal("AddHost should succeed")
func TestRegistry_AddHostOnce(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(a)
r.AddHost("host0", nil, a)
tests := []string{
"host0",
"host0:80",
"host0:8080",
}
r.Subscribe(id.NewFromString("B"))
if err := r.AddHost("foobar", id.NewFromString("B")); err == nil {
t.Fatal("AddHost for duplicate host should fail")
for _, tt := range tests {
if err := r.AddHost(tt, nil, a); !strings.Contains(err.Error(), "occupied") {
t.Log(tt)
t.Errorf("Adding host %q should fail", tt)
}
}
if identifier, ok := r.Subscriber("foobar"); !ok || identifier != id.NewFromString("A") {
t.Fatal("Wrong subscriber")
}
if identifier, ok := r.Subscriber("foobar:8080"); !ok || identifier != id.NewFromString("A") {
t.Fatal("Wrong subscriber")
}
r.Subscribe(b)
r.Unsubscribe(id.NewFromString("A"))
if err := r.AddHost("foobar", id.NewFromString("B")); err != nil {
t.Fatal("Unsubsribe failed to remove host")
for _, tt := range tests {
if err := r.AddHost(tt, nil, b); !strings.Contains(err.Error(), "occupied") {
t.Log(err)
t.Errorf("Adding host %q should fail", tt)
}
}
}
@ -63,49 +160,71 @@ func TestRegistry_DeleteHost(t *testing.T) {
t.Parallel()
r := newRegistry()
r.Subscribe(id.NewFromString("A"))
r.Subscribe(a)
r.AddHost("host0", nil, a)
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
t.Fatal("AddHost should succeed")
r.DeleteHost("host0", a)
if _, _, ok := r.Subscriber("host0"); ok {
t.Fatal("Should delete host for a")
}
if identifier, ok := r.Subscriber("foobar"); !ok || identifier != id.NewFromString("A") {
t.Fatal("Wrong subscriber")
}
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err == nil {
t.Fatal("AddHost for duplicate host should fail")
}
r.DeleteHost("foobar", id.NewFromString("A"))
if _, ok := r.Subscriber("foobar"); ok {
t.Fatal("DeleteHost failed to delete host")
}
if err := r.AddHost("foobar:8080", id.NewFromString("A")); err != nil {
t.Fatal("AddHost should succeed")
}
r.Subscribe(id.NewFromString("B"))
r.DeleteHost("foobar", id.NewFromString("B"))
if _, ok := r.Subscriber("foobar"); !ok {
t.Fatal("DeleteHost forgein host should have no effect")
i := r.Unsubscribe(a)
if len(i.Hosts) != 0 {
t.Fatal("Host was not deleted from item")
}
}
func TestRegistry_AddListener(t *testing.T) {
func TestRegistry_DeleteOnlyOwnedHost(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddListener(&net.TCPListener{}, id.NewFromString("A")); err != errClientNotSubscribed {
t.Fatal("AddListener to not subscribed client should fail")
}
r.Subscribe(a)
r.AddHost("host0", nil, a)
r.Subscribe(id.NewFromString("A"))
r.DeleteHost("host0", b)
if err := r.AddListener(&net.TCPListener{}, id.NewFromString("A")); err != nil {
t.Fatal("AddListener should succeed")
if _, _, ok := r.Subscriber("host0"); !ok {
t.Fatal("Should not delete host for b")
}
}
func TestRegistry_AddListenerOnlyToSubscribed(t *testing.T) {
t.Parallel()
r := newRegistry()
if err := r.AddListener(&net.TCPListener{}, a); err != errClientNotSubscribed {
t.Fatal("Adding listener should be possible to subscribned clients only")
}
}
func TestRegistry_AddListenerOnce(t *testing.T) {
t.Parallel()
l := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l, a)
if err := r.AddListener(l, a); err == nil {
t.Fatal("Adding listenr should fail")
}
}
func TestRegistry_DeleteListener(t *testing.T) {
t.Parallel()
l := &net.TCPListener{}
r := newRegistry()
r.Subscribe(a)
r.AddListener(l, a)
r.DeleteListener(l, a)
i := r.Unsubscribe(a)
if len(i.Listeners) != 0 {
t.Fatal("Host was not deleted from item")
}
}

View file

@ -267,7 +267,7 @@ func (s *Server) AddTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID)
for name, t := range tunnels {
switch t.Protocol {
case proto.HTTP:
err = s.AddHost(t.Host, identifier)
err = s.AddHost(t.Host, NewAuth(t.Auth), identifier)
if err != nil {
goto rollback
}
@ -399,6 +399,11 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe
// ServeHTTP proxies http connection to the client.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resp, err := s.RoundTrip(r)
if err == errUnauthorised {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
@ -424,9 +429,15 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
ForwardedBy: r.Host,
}
identifier, ok := s.Subscriber(r.Host)
identifier, auth, ok := s.Subscriber(r.Host)
if !ok {
return nil, fmt.Errorf("proxy request error: %s", errClientNotSubscribed)
return nil, errClientNotSubscribed
}
if auth != nil {
user, password, _ := r.BasicAuth()
if auth.User != user || auth.Password != password {
return nil, errUnauthorised
}
}
return s.proxyHTTP(identifier, r, msg)