diff --git a/pool.go b/pool.go index 08ef6d8..9ede026 100644 --- a/pool.go +++ b/pool.go @@ -5,10 +5,12 @@ package tunnel import ( + "context" "fmt" "net" "net/http" "sync" + "time" "golang.org/x/net/http2" @@ -37,6 +39,10 @@ func newConnPool(t *http2.Transport, l onDisconnectListener) *connPool { } } +func (p *connPool) URL(identifier id.ID) string { + return fmt.Sprint("https://", identifier) +} + func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) { p.mu.RLock() defer p.mu.RUnlock() @@ -54,11 +60,7 @@ func (p *connPool) MarkDead(c *http2.ClientConn) { for addr, cp := range p.conns { if cp.clientConn == c { - cp.conn.Close() - delete(p.conns, addr) - if p.listener != nil { - p.listener(p.addrToIdentifier(addr)) - } + p.close(cp, addr) return } } @@ -70,8 +72,12 @@ func (p *connPool) AddConn(conn net.Conn, identifier id.ID) error { addr := p.addr(identifier) - if _, ok := p.conns[addr]; ok { - return errClientAlreadyConnected + if cp, ok := p.conns[addr]; ok { + if err := p.ping(cp); err != nil { + p.close(cp, addr) + } else { + return errClientAlreadyConnected + } } c, err := p.t.NewClientConn(conn) @@ -93,24 +99,46 @@ func (p *connPool) DeleteConn(identifier id.ID) { addr := p.addr(identifier) if cp, ok := p.conns[addr]; ok { - cp.conn.Close() - delete(p.conns, addr) - if p.listener != nil { - p.listener(identifier) - } + p.close(cp, addr) } } -func (p *connPool) URL(identifier id.ID) string { - return fmt.Sprint("https://", identifier) +func (p *connPool) Ping(identifier id.ID) (time.Duration, error) { + p.mu.Lock() + defer p.mu.Unlock() + + addr := p.addr(identifier) + + if cp, ok := p.conns[addr]; ok { + start := time.Now() + err := p.ping(cp) + return time.Since(start), err + } + + return 0, errClientNotConnected +} + +func (p *connPool) ping(cp connPair) error { + ctx, cancel := context.WithTimeout(context.Background(), DefaultPingTimeout) + defer cancel() + + return cp.clientConn.Ping(ctx) +} + +func (p *connPool) close(cp connPair, addr string) { + cp.conn.Close() + delete(p.conns, addr) + if p.listener != nil { + p.listener(p.identifier(addr)) + } } func (p *connPool) addr(identifier id.ID) string { return fmt.Sprint(identifier.String(), ":443") } -func (p *connPool) addrToIdentifier(addr string) id.ID { - identifier := id.ID{} +func (p *connPool) identifier(addr string) id.ID { + var identifier id.ID identifier.UnmarshalText([]byte(addr[:len(addr)-4])) return identifier } diff --git a/server.go b/server.go index 6854a79..e4b0c7d 100644 --- a/server.go +++ b/server.go @@ -150,7 +150,7 @@ func (s *Server) Start() { if err := keepAlive(conn); err != nil { s.logger.Log( "level", 1, - "msg", "connection keepAlive failed", + "msg", "could not enable TCP keepalive for control connection", "addr", addr, "err", err, ) @@ -408,6 +408,11 @@ func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { return s.registry.Unsubscribe(identifier) } +// Ping measures the RTT response time. +func (s *Server) Ping(identifier id.ID) (time.Duration, error) { + return s.connPool.Ping(identifier) +} + func (s *Server) listen(l net.Listener, identifier id.ID) { addr := l.Addr().String()