Added support for subscription listeners, dynamic subscription authorization and to upgrade a http connection for websocket support

This commit is contained in:
Ivano Culmine 2022-07-13 11:30:55 +02:00 committed by Michal Jan Matczuk
parent 77db4b5a50
commit 9da0263137

155
server.go
View file

@ -7,6 +7,7 @@ package tunnel
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -24,6 +25,20 @@ import (
"github.com/mmatczuk/go-http-tunnel/proto" "github.com/mmatczuk/go-http-tunnel/proto"
) )
// A set of listeners to manage subscribers
type SubscriptionListener interface {
// Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
// verified chain, else the presented peer certificate.
CanSubscribe(id id.ID, chain []*x509.Certificate) bool
// Invoked when the client has been subscribed.
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
// verified chain, else the presented peer certificate.
Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate)
// Invoked before the client is unsubscribed.
Unsubscribed(id id.ID)
}
// ServerConfig defines configuration for the Server. // ServerConfig defines configuration for the Server.
type ServerConfig struct { type ServerConfig struct {
// Addr is TCP address to listen for client connections. If empty ":0" // Addr is TCP address to listen for client connections. If empty ":0"
@ -41,6 +56,8 @@ type ServerConfig struct {
Logger log.Logger Logger log.Logger
// Addr is TCP address to listen for TLS SNI connections // Addr is TCP address to listen for TLS SNI connections
SNIAddr string SNIAddr string
// Optional listener to manage subscribers
SubscriptionListener SubscriptionListener
} }
// Server is responsible for proxying public connections to the client over a // Server is responsible for proxying public connections to the client over a
@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
ok bool ok bool
inConnPool bool inConnPool bool
certs []*x509.Certificate
) )
tlsConn, ok := conn.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
logger = logger.With("identifier", identifier) logger = logger.With("identifier", identifier)
certs = tlsConn.ConnectionState().PeerCertificates
if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 {
certs = tlsConn.ConnectionState().VerifiedChains[0]
}
if s.config.AutoSubscribe { if s.config.AutoSubscribe {
s.Subscribe(identifier) s.Subscribe(identifier)
if s.config.SubscriptionListener != nil {
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
}
} else if !s.IsSubscribed(identifier) { } else if !s.IsSubscribed(identifier) {
logger.Log( if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) {
"level", 2, s.Subscribe(identifier)
"msg", "unknown client", s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
) } else {
goto reject logger.Log(
"level", 2,
"msg", "unknown client",
)
goto reject
}
} }
if err = conn.SetDeadline(time.Time{}); err != nil { if err = conn.SetDeadline(time.Time{}); err != nil {
@ -486,6 +516,9 @@ rollback:
// Unsubscribe removes client from registry, disconnects client if already // Unsubscribe removes client from registry, disconnects client if already
// connected and returns it's RegistryItem. // connected and returns it's RegistryItem.
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
if s.config.SubscriptionListener != nil {
s.config.SubscriptionListener.Unsubscribed(identifier)
}
s.connPool.DeleteConn(identifier) s.connPool.DeleteConn(identifier)
return s.registry.Unsubscribe(identifier) return s.registry.Unsubscribe(identifier)
} }
@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
} }
} }
func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error {
var err error
msg := &proto.ControlMessage{
Action: proto.ActionProxy,
ForwardedProto: "https",
}
tlsConn, ok := conn.(*tls.Conn)
if ok {
msg.ForwardedHost = tlsConn.ConnectionState().ServerName
err = keepAlive(tlsConn.NetConn())
} else {
msg.ForwardedHost = conn.RemoteAddr().String()
err = keepAlive(conn)
}
if err != nil {
s.logger.Log(
"level", 1,
"msg", "TCP keepalive for tunneled connection failed",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
}
go func() {
if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil {
s.logger.Log(
"level", 0,
"msg", "proxy error",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
}
}()
return nil
}
// ServeHTTP proxies http connection to the client. // ServeHTTP proxies http connection to the client.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resp, err := s.RoundTrip(r) resp, err := s.RoundTrip(r)
@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
return s.proxyHTTP(identifier, outr, msg) return s.proxyHTTP(identifier, outr, msg)
} }
func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error {
s.logger.Log(
"level", 2,
"action", "proxy conn",
"identifier", identifier,
"ctrlMsg", msg,
)
defer conn.Close()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
continueChan := make(chan int)
go func() {
pw.Write(requestBytes)
continueChan <- 1
}()
req, err := s.connectRequest(identifier, msg, pr)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)
done := make(chan struct{})
go func() {
<-continueChan
transfer(pw, conn, log.NewContext(s.logger).With(
"dir", "user to client",
"dst", identifier,
"src", conn.RemoteAddr(),
))
cancel()
close(done)
}()
resp, err := s.httpClient.Do(req)
if err != nil {
return fmt.Errorf("io error: %s", err)
}
defer resp.Body.Close()
transfer(conn, resp.Body, log.NewContext(s.logger).With(
"dir", "client to user",
"dst", conn.RemoteAddr(),
"src", identifier,
))
select {
case <-done:
case <-time.After(DefaultTimeout):
}
s.logger.Log(
"level", 2,
"action", "proxy conn done",
"identifier", identifier,
"ctrlMsg", msg,
)
return nil
}
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error { func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
s.logger.Log( s.logger.Log(
"level", 2, "level", 2,