diff --git a/server.go b/server.go index e999b2a..e189e15 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ package tunnel import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -24,6 +25,20 @@ import ( "github.com/mmatczuk/go-http-tunnel/proto" ) +// A set of listeners to manage subscribers +type SubscriptionListener interface { + // Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not. + // If the tlsConfig is configured to require client certificate validation, chain will contain the first + // verified chain, else the presented peer certificate. + CanSubscribe(id id.ID, chain []*x509.Certificate) bool + // Invoked when the client has been subscribed. + // If the tlsConfig is configured to require client certificate validation, chain will contain the first + // verified chain, else the presented peer certificate. + Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate) + // Invoked before the client is unsubscribed. + Unsubscribed(id id.ID) +} + // ServerConfig defines configuration for the Server. type ServerConfig struct { // Addr is TCP address to listen for client connections. If empty ":0" @@ -41,6 +56,8 @@ type ServerConfig struct { Logger log.Logger // Addr is TCP address to listen for TLS SNI connections SNIAddr string + // Optional listener to manage subscribers + SubscriptionListener SubscriptionListener } // Server is responsible for proxying public connections to the client over a @@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) { ok bool inConnPool bool + certs []*x509.Certificate ) tlsConn, ok := conn.(*tls.Conn) @@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) { logger = logger.With("identifier", identifier) + certs = tlsConn.ConnectionState().PeerCertificates + if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 { + certs = tlsConn.ConnectionState().VerifiedChains[0] + } if s.config.AutoSubscribe { s.Subscribe(identifier) + if s.config.SubscriptionListener != nil { + s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs) + } } else if !s.IsSubscribed(identifier) { - logger.Log( - "level", 2, - "msg", "unknown client", - ) - goto reject + if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) { + s.Subscribe(identifier) + s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs) + } else { + logger.Log( + "level", 2, + "msg", "unknown client", + ) + goto reject + } } if err = conn.SetDeadline(time.Time{}); err != nil { @@ -486,6 +516,9 @@ rollback: // Unsubscribe removes client from registry, disconnects client if already // connected and returns it's RegistryItem. func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { + if s.config.SubscriptionListener != nil { + s.config.SubscriptionListener.Unsubscribed(identifier) + } s.connPool.DeleteConn(identifier) return s.registry.Unsubscribe(identifier) } @@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) { } } +func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error { + + var err error + + msg := &proto.ControlMessage{ + Action: proto.ActionProxy, + ForwardedProto: "https", + } + + tlsConn, ok := conn.(*tls.Conn) + if ok { + msg.ForwardedHost = tlsConn.ConnectionState().ServerName + err = keepAlive(tlsConn.NetConn()) + + } else { + msg.ForwardedHost = conn.RemoteAddr().String() + err = keepAlive(conn) + } + + if err != nil { + s.logger.Log( + "level", 1, + "msg", "TCP keepalive for tunneled connection failed", + "identifier", identifier, + "ctrlMsg", msg, + "err", err, + ) + } + + go func() { + if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil { + s.logger.Log( + "level", 0, + "msg", "proxy error", + "identifier", identifier, + "ctrlMsg", msg, + "err", err, + ) + } + }() + + return nil +} + // ServeHTTP proxies http connection to the client. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, err := s.RoundTrip(r) @@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { return s.proxyHTTP(identifier, outr, msg) } +func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error { + s.logger.Log( + "level", 2, + "action", "proxy conn", + "identifier", identifier, + "ctrlMsg", msg, + ) + + defer conn.Close() + + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + + continueChan := make(chan int) + + go func() { + pw.Write(requestBytes) + continueChan <- 1 + }() + + req, err := s.connectRequest(identifier, msg, pr) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + + done := make(chan struct{}) + go func() { + <-continueChan + transfer(pw, conn, log.NewContext(s.logger).With( + "dir", "user to client", + "dst", identifier, + "src", conn.RemoteAddr(), + )) + cancel() + close(done) + }() + + resp, err := s.httpClient.Do(req) + if err != nil { + return fmt.Errorf("io error: %s", err) + } + defer resp.Body.Close() + + transfer(conn, resp.Body, log.NewContext(s.logger).With( + "dir", "client to user", + "dst", conn.RemoteAddr(), + "src", identifier, + )) + + select { + case <-done: + case <-time.After(DefaultTimeout): + } + + s.logger.Log( + "level", 2, + "action", "proxy conn done", + "identifier", identifier, + "ctrlMsg", msg, + ) + + return nil +} + func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error { s.logger.Log( "level", 2,