mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
Added support for subscription listeners, dynamic subscription authorization and to upgrade a http connection for websocket support
This commit is contained in:
parent
77db4b5a50
commit
9da0263137
1 changed files with 150 additions and 5 deletions
155
server.go
155
server.go
|
|
@ -7,6 +7,7 @@ package tunnel
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -24,6 +25,20 @@ import (
|
||||||
"github.com/mmatczuk/go-http-tunnel/proto"
|
"github.com/mmatczuk/go-http-tunnel/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A set of listeners to manage subscribers
|
||||||
|
type SubscriptionListener interface {
|
||||||
|
// Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
|
||||||
|
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
|
||||||
|
// verified chain, else the presented peer certificate.
|
||||||
|
CanSubscribe(id id.ID, chain []*x509.Certificate) bool
|
||||||
|
// Invoked when the client has been subscribed.
|
||||||
|
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
|
||||||
|
// verified chain, else the presented peer certificate.
|
||||||
|
Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate)
|
||||||
|
// Invoked before the client is unsubscribed.
|
||||||
|
Unsubscribed(id id.ID)
|
||||||
|
}
|
||||||
|
|
||||||
// ServerConfig defines configuration for the Server.
|
// ServerConfig defines configuration for the Server.
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
// Addr is TCP address to listen for client connections. If empty ":0"
|
// Addr is TCP address to listen for client connections. If empty ":0"
|
||||||
|
|
@ -41,6 +56,8 @@ type ServerConfig struct {
|
||||||
Logger log.Logger
|
Logger log.Logger
|
||||||
// Addr is TCP address to listen for TLS SNI connections
|
// Addr is TCP address to listen for TLS SNI connections
|
||||||
SNIAddr string
|
SNIAddr string
|
||||||
|
// Optional listener to manage subscribers
|
||||||
|
SubscriptionListener SubscriptionListener
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server is responsible for proxying public connections to the client over a
|
// Server is responsible for proxying public connections to the client over a
|
||||||
|
|
@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
|
||||||
ok bool
|
ok bool
|
||||||
|
|
||||||
inConnPool bool
|
inConnPool bool
|
||||||
|
certs []*x509.Certificate
|
||||||
)
|
)
|
||||||
|
|
||||||
tlsConn, ok := conn.(*tls.Conn)
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
|
|
@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
|
||||||
|
|
||||||
logger = logger.With("identifier", identifier)
|
logger = logger.With("identifier", identifier)
|
||||||
|
|
||||||
|
certs = tlsConn.ConnectionState().PeerCertificates
|
||||||
|
if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 {
|
||||||
|
certs = tlsConn.ConnectionState().VerifiedChains[0]
|
||||||
|
}
|
||||||
if s.config.AutoSubscribe {
|
if s.config.AutoSubscribe {
|
||||||
s.Subscribe(identifier)
|
s.Subscribe(identifier)
|
||||||
|
if s.config.SubscriptionListener != nil {
|
||||||
|
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
|
||||||
|
}
|
||||||
} else if !s.IsSubscribed(identifier) {
|
} else if !s.IsSubscribed(identifier) {
|
||||||
logger.Log(
|
if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) {
|
||||||
"level", 2,
|
s.Subscribe(identifier)
|
||||||
"msg", "unknown client",
|
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
|
||||||
)
|
} else {
|
||||||
goto reject
|
logger.Log(
|
||||||
|
"level", 2,
|
||||||
|
"msg", "unknown client",
|
||||||
|
)
|
||||||
|
goto reject
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
|
@ -486,6 +516,9 @@ rollback:
|
||||||
// Unsubscribe removes client from registry, disconnects client if already
|
// Unsubscribe removes client from registry, disconnects client if already
|
||||||
// connected and returns it's RegistryItem.
|
// connected and returns it's RegistryItem.
|
||||||
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
|
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
|
||||||
|
if s.config.SubscriptionListener != nil {
|
||||||
|
s.config.SubscriptionListener.Unsubscribed(identifier)
|
||||||
|
}
|
||||||
s.connPool.DeleteConn(identifier)
|
s.connPool.DeleteConn(identifier)
|
||||||
return s.registry.Unsubscribe(identifier)
|
return s.registry.Unsubscribe(identifier)
|
||||||
}
|
}
|
||||||
|
|
@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
msg := &proto.ControlMessage{
|
||||||
|
Action: proto.ActionProxy,
|
||||||
|
ForwardedProto: "https",
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
|
if ok {
|
||||||
|
msg.ForwardedHost = tlsConn.ConnectionState().ServerName
|
||||||
|
err = keepAlive(tlsConn.NetConn())
|
||||||
|
|
||||||
|
} else {
|
||||||
|
msg.ForwardedHost = conn.RemoteAddr().String()
|
||||||
|
err = keepAlive(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Log(
|
||||||
|
"level", 1,
|
||||||
|
"msg", "TCP keepalive for tunneled connection failed",
|
||||||
|
"identifier", identifier,
|
||||||
|
"ctrlMsg", msg,
|
||||||
|
"err", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil {
|
||||||
|
s.logger.Log(
|
||||||
|
"level", 0,
|
||||||
|
"msg", "proxy error",
|
||||||
|
"identifier", identifier,
|
||||||
|
"ctrlMsg", msg,
|
||||||
|
"err", err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ServeHTTP proxies http connection to the client.
|
// ServeHTTP proxies http connection to the client.
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
resp, err := s.RoundTrip(r)
|
resp, err := s.RoundTrip(r)
|
||||||
|
|
@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
return s.proxyHTTP(identifier, outr, msg)
|
return s.proxyHTTP(identifier, outr, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error {
|
||||||
|
s.logger.Log(
|
||||||
|
"level", 2,
|
||||||
|
"action", "proxy conn",
|
||||||
|
"identifier", identifier,
|
||||||
|
"ctrlMsg", msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
defer pr.Close()
|
||||||
|
defer pw.Close()
|
||||||
|
|
||||||
|
continueChan := make(chan int)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
pw.Write(requestBytes)
|
||||||
|
continueChan <- 1
|
||||||
|
}()
|
||||||
|
|
||||||
|
req, err := s.connectRequest(identifier, msg, pr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
<-continueChan
|
||||||
|
transfer(pw, conn, log.NewContext(s.logger).With(
|
||||||
|
"dir", "user to client",
|
||||||
|
"dst", identifier,
|
||||||
|
"src", conn.RemoteAddr(),
|
||||||
|
))
|
||||||
|
cancel()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
resp, err := s.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("io error: %s", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
transfer(conn, resp.Body, log.NewContext(s.logger).With(
|
||||||
|
"dir", "client to user",
|
||||||
|
"dst", conn.RemoteAddr(),
|
||||||
|
"src", identifier,
|
||||||
|
))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(DefaultTimeout):
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Log(
|
||||||
|
"level", 2,
|
||||||
|
"action", "proxy conn done",
|
||||||
|
"identifier", identifier,
|
||||||
|
"ctrlMsg", msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
|
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
|
||||||
s.logger.Log(
|
s.logger.Log(
|
||||||
"level", 2,
|
"level", 2,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue