mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 06:06:03 -06:00
Added support for subscription listeners, dynamic subscription authorization and to upgrade a http connection for websocket support
This commit is contained in:
parent
77db4b5a50
commit
9da0263137
1 changed files with 150 additions and 5 deletions
155
server.go
155
server.go
|
|
@ -7,6 +7,7 @@ package tunnel
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -24,6 +25,20 @@ import (
|
|||
"github.com/mmatczuk/go-http-tunnel/proto"
|
||||
)
|
||||
|
||||
// A set of listeners to manage subscribers
|
||||
type SubscriptionListener interface {
|
||||
// Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
|
||||
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
|
||||
// verified chain, else the presented peer certificate.
|
||||
CanSubscribe(id id.ID, chain []*x509.Certificate) bool
|
||||
// Invoked when the client has been subscribed.
|
||||
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
|
||||
// verified chain, else the presented peer certificate.
|
||||
Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate)
|
||||
// Invoked before the client is unsubscribed.
|
||||
Unsubscribed(id id.ID)
|
||||
}
|
||||
|
||||
// ServerConfig defines configuration for the Server.
|
||||
type ServerConfig struct {
|
||||
// Addr is TCP address to listen for client connections. If empty ":0"
|
||||
|
|
@ -41,6 +56,8 @@ type ServerConfig struct {
|
|||
Logger log.Logger
|
||||
// Addr is TCP address to listen for TLS SNI connections
|
||||
SNIAddr string
|
||||
// Optional listener to manage subscribers
|
||||
SubscriptionListener SubscriptionListener
|
||||
}
|
||||
|
||||
// Server is responsible for proxying public connections to the client over a
|
||||
|
|
@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
ok bool
|
||||
|
||||
inConnPool bool
|
||||
certs []*x509.Certificate
|
||||
)
|
||||
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
|
|
@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
|
|||
|
||||
logger = logger.With("identifier", identifier)
|
||||
|
||||
certs = tlsConn.ConnectionState().PeerCertificates
|
||||
if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 {
|
||||
certs = tlsConn.ConnectionState().VerifiedChains[0]
|
||||
}
|
||||
if s.config.AutoSubscribe {
|
||||
s.Subscribe(identifier)
|
||||
if s.config.SubscriptionListener != nil {
|
||||
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
|
||||
}
|
||||
} else if !s.IsSubscribed(identifier) {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "unknown client",
|
||||
)
|
||||
goto reject
|
||||
if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) {
|
||||
s.Subscribe(identifier)
|
||||
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
|
||||
} else {
|
||||
logger.Log(
|
||||
"level", 2,
|
||||
"msg", "unknown client",
|
||||
)
|
||||
goto reject
|
||||
}
|
||||
}
|
||||
|
||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||
|
|
@ -486,6 +516,9 @@ rollback:
|
|||
// Unsubscribe removes client from registry, disconnects client if already
|
||||
// connected and returns it's RegistryItem.
|
||||
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
|
||||
if s.config.SubscriptionListener != nil {
|
||||
s.config.SubscriptionListener.Unsubscribed(identifier)
|
||||
}
|
||||
s.connPool.DeleteConn(identifier)
|
||||
return s.registry.Unsubscribe(identifier)
|
||||
}
|
||||
|
|
@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error {
|
||||
|
||||
var err error
|
||||
|
||||
msg := &proto.ControlMessage{
|
||||
Action: proto.ActionProxy,
|
||||
ForwardedProto: "https",
|
||||
}
|
||||
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if ok {
|
||||
msg.ForwardedHost = tlsConn.ConnectionState().ServerName
|
||||
err = keepAlive(tlsConn.NetConn())
|
||||
|
||||
} else {
|
||||
msg.ForwardedHost = conn.RemoteAddr().String()
|
||||
err = keepAlive(conn)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.logger.Log(
|
||||
"level", 1,
|
||||
"msg", "TCP keepalive for tunneled connection failed",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"err", err,
|
||||
)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil {
|
||||
s.logger.Log(
|
||||
"level", 0,
|
||||
"msg", "proxy error",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
"err", err,
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeHTTP proxies http connection to the client.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := s.RoundTrip(r)
|
||||
|
|
@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||
return s.proxyHTTP(identifier, outr, msg)
|
||||
}
|
||||
|
||||
func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error {
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "proxy conn",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
defer pw.Close()
|
||||
|
||||
continueChan := make(chan int)
|
||||
|
||||
go func() {
|
||||
pw.Write(requestBytes)
|
||||
continueChan <- 1
|
||||
}()
|
||||
|
||||
req, err := s.connectRequest(identifier, msg, pr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-continueChan
|
||||
transfer(pw, conn, log.NewContext(s.logger).With(
|
||||
"dir", "user to client",
|
||||
"dst", identifier,
|
||||
"src", conn.RemoteAddr(),
|
||||
))
|
||||
cancel()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("io error: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
transfer(conn, resp.Body, log.NewContext(s.logger).With(
|
||||
"dir", "client to user",
|
||||
"dst", conn.RemoteAddr(),
|
||||
"src", identifier,
|
||||
))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(DefaultTimeout):
|
||||
}
|
||||
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
"action", "proxy conn done",
|
||||
"identifier", identifier,
|
||||
"ctrlMsg", msg,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
|
||||
s.logger.Log(
|
||||
"level", 2,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue