Added support for subscription listeners, dynamic subscription authorization and to upgrade a http connection for websocket support

This commit is contained in:
Ivano Culmine 2022-07-13 11:30:55 +02:00 committed by Michal Jan Matczuk
parent 77db4b5a50
commit 9da0263137

155
server.go
View file

@ -7,6 +7,7 @@ package tunnel
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
@ -24,6 +25,20 @@ import (
"github.com/mmatczuk/go-http-tunnel/proto"
)
// A set of listeners to manage subscribers
type SubscriptionListener interface {
// Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
// verified chain, else the presented peer certificate.
CanSubscribe(id id.ID, chain []*x509.Certificate) bool
// Invoked when the client has been subscribed.
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
// verified chain, else the presented peer certificate.
Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate)
// Invoked before the client is unsubscribed.
Unsubscribed(id id.ID)
}
// ServerConfig defines configuration for the Server.
type ServerConfig struct {
// Addr is TCP address to listen for client connections. If empty ":0"
@ -41,6 +56,8 @@ type ServerConfig struct {
Logger log.Logger
// Addr is TCP address to listen for TLS SNI connections
SNIAddr string
// Optional listener to manage subscribers
SubscriptionListener SubscriptionListener
}
// Server is responsible for proxying public connections to the client over a
@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
ok bool
inConnPool bool
certs []*x509.Certificate
)
tlsConn, ok := conn.(*tls.Conn)
@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
logger = logger.With("identifier", identifier)
certs = tlsConn.ConnectionState().PeerCertificates
if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 {
certs = tlsConn.ConnectionState().VerifiedChains[0]
}
if s.config.AutoSubscribe {
s.Subscribe(identifier)
if s.config.SubscriptionListener != nil {
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
}
} else if !s.IsSubscribed(identifier) {
logger.Log(
"level", 2,
"msg", "unknown client",
)
goto reject
if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) {
s.Subscribe(identifier)
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
} else {
logger.Log(
"level", 2,
"msg", "unknown client",
)
goto reject
}
}
if err = conn.SetDeadline(time.Time{}); err != nil {
@ -486,6 +516,9 @@ rollback:
// Unsubscribe removes client from registry, disconnects client if already
// connected and returns it's RegistryItem.
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
if s.config.SubscriptionListener != nil {
s.config.SubscriptionListener.Unsubscribed(identifier)
}
s.connPool.DeleteConn(identifier)
return s.registry.Unsubscribe(identifier)
}
@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
}
}
func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error {
var err error
msg := &proto.ControlMessage{
Action: proto.ActionProxy,
ForwardedProto: "https",
}
tlsConn, ok := conn.(*tls.Conn)
if ok {
msg.ForwardedHost = tlsConn.ConnectionState().ServerName
err = keepAlive(tlsConn.NetConn())
} else {
msg.ForwardedHost = conn.RemoteAddr().String()
err = keepAlive(conn)
}
if err != nil {
s.logger.Log(
"level", 1,
"msg", "TCP keepalive for tunneled connection failed",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
}
go func() {
if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil {
s.logger.Log(
"level", 0,
"msg", "proxy error",
"identifier", identifier,
"ctrlMsg", msg,
"err", err,
)
}
}()
return nil
}
// ServeHTTP proxies http connection to the client.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resp, err := s.RoundTrip(r)
@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
return s.proxyHTTP(identifier, outr, msg)
}
func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error {
s.logger.Log(
"level", 2,
"action", "proxy conn",
"identifier", identifier,
"ctrlMsg", msg,
)
defer conn.Close()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
continueChan := make(chan int)
go func() {
pw.Write(requestBytes)
continueChan <- 1
}()
req, err := s.connectRequest(identifier, msg, pr)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)
done := make(chan struct{})
go func() {
<-continueChan
transfer(pw, conn, log.NewContext(s.logger).With(
"dir", "user to client",
"dst", identifier,
"src", conn.RemoteAddr(),
))
cancel()
close(done)
}()
resp, err := s.httpClient.Do(req)
if err != nil {
return fmt.Errorf("io error: %s", err)
}
defer resp.Body.Close()
transfer(conn, resp.Body, log.NewContext(s.logger).With(
"dir", "client to user",
"dst", conn.RemoteAddr(),
"src", identifier,
))
select {
case <-done:
case <-time.After(DefaultTimeout):
}
s.logger.Log(
"level", 2,
"action", "proxy conn done",
"identifier", identifier,
"ctrlMsg", msg,
)
return nil
}
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
s.logger.Log(
"level", 2,