refactor: clean up code (#5308)
Some checks failed
golangci-lint / lint (push) Has been cancelled

This commit is contained in:
fatedier 2026-05-12 11:13:50 +08:00 committed by GitHub
parent ad07d27914
commit a88e0e9a49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 2082 additions and 931 deletions

2
go.mod
View file

@ -36,6 +36,7 @@ require (
gopkg.in/ini.v1 v1.67.0
k8s.io/apimachinery v0.28.8
k8s.io/client-go v0.28.8
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
)
require (
@ -75,7 +76,6 @@ require (
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)

View file

@ -14,11 +14,7 @@
package source
import (
"fmt"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
import v1 "github.com/fatedier/frp/pkg/config/v1"
// ConfigSource implements Source for in-memory configuration.
// All operations are thread-safe.
@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
for _, p := range proxies {
if p == nil {
return fmt.Errorf("proxy cannot be nil")
}
name := p.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
name, err := validateProxyName(p)
if err != nil {
return err
}
nextProxies[name] = p
}
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
for _, v := range visitors {
if v == nil {
return fmt.Errorf("visitor cannot be nil")
}
name := v.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
name, err := validateVisitorName(v)
if err != nil {
return err
}
nextVisitors[name] = v
}

View file

@ -43,6 +43,11 @@ var (
ErrNotFound = errors.New("not found")
)
const (
storeKindProxy = "proxy"
storeKindVisitor = "visitor"
)
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
if cfg.Path == "" {
return nil, fmt.Errorf("path is required")
@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error {
return nil
}
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
if proxy == nil {
return fmt.Errorf("proxy cannot be nil")
func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error {
if err := s.saveToFileUnlocked(); err != nil {
rollback()
return fmt.Errorf("failed to persist: %w", err)
}
return nil
}
name := proxy.GetBaseConfig().Name
// Store map selectors return the target map for generic helpers.
func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer {
return s.proxies
}
func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer {
return s.visitors
}
// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps.
// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer.
func addStoreEntry[T any](
s *StoreSource,
entriesFn func(*StoreSource) map[string]T,
kind string,
name string,
value T,
) error {
s.mu.Lock()
defer s.mu.Unlock()
entries := entriesFn(s)
if _, exists := entries[name]; exists {
return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name)
}
entries[name] = value
return s.persistOrRollbackUnlocked(func() {
delete(entries, name)
})
}
func updateStoreEntry[T any](
s *StoreSource,
entriesFn func(*StoreSource) map[string]T,
kind string,
name string,
value T,
) error {
s.mu.Lock()
defer s.mu.Unlock()
entries := entriesFn(s)
old, exists := entries[name]
if !exists {
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
}
entries[name] = value
return s.persistOrRollbackUnlocked(func() {
entries[name] = old
})
}
func removeStoreEntry[T any](
s *StoreSource,
entriesFn func(*StoreSource) map[string]T,
kind string,
name string,
) error {
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
return fmt.Errorf("%s name cannot be empty", kind)
}
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.proxies[name]; exists {
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
entries := entriesFn(s)
old, exists := entries[name]
if !exists {
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
}
s.proxies[name] = proxy
if err := s.saveToFileUnlocked(); err != nil {
delete(s.proxies, name)
return fmt.Errorf("failed to persist: %w", err)
delete(entries, name)
return s.persistOrRollbackUnlocked(func() {
entries[name] = old
})
}
return nil
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
name, err := validateProxyName(proxy)
if err != nil {
return err
}
return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
}
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
if proxy == nil {
return fmt.Errorf("proxy cannot be nil")
name, err := validateProxyName(proxy)
if err != nil {
return err
}
name := proxy.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
oldProxy, exists := s.proxies[name]
if !exists {
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
}
s.proxies[name] = proxy
if err := s.saveToFileUnlocked(); err != nil {
s.proxies[name] = oldProxy
return fmt.Errorf("failed to persist: %w", err)
}
return nil
return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
}
func (s *StoreSource) RemoveProxy(name string) error {
if name == "" {
return fmt.Errorf("proxy name cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
oldProxy, exists := s.proxies[name]
if !exists {
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
}
delete(s.proxies, name)
if err := s.saveToFileUnlocked(); err != nil {
s.proxies[name] = oldProxy
return fmt.Errorf("failed to persist: %w", err)
}
return nil
return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name)
}
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
}
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
if visitor == nil {
return fmt.Errorf("visitor cannot be nil")
name, err := validateVisitorName(visitor)
if err != nil {
return err
}
name := visitor.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
if _, exists := s.visitors[name]; exists {
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
}
s.visitors[name] = visitor
if err := s.saveToFileUnlocked(); err != nil {
delete(s.visitors, name)
return fmt.Errorf("failed to persist: %w", err)
}
return nil
return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
}
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
if visitor == nil {
return fmt.Errorf("visitor cannot be nil")
name, err := validateVisitorName(visitor)
if err != nil {
return err
}
name := visitor.GetBaseConfig().Name
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
oldVisitor, exists := s.visitors[name]
if !exists {
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
}
s.visitors[name] = visitor
if err := s.saveToFileUnlocked(); err != nil {
s.visitors[name] = oldVisitor
return fmt.Errorf("failed to persist: %w", err)
}
return nil
return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
}
func (s *StoreSource) RemoveVisitor(name string) error {
if name == "" {
return fmt.Errorf("visitor name cannot be empty")
}
s.mu.Lock()
defer s.mu.Unlock()
oldVisitor, exists := s.visitors[name]
if !exists {
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
}
delete(s.visitors, name)
if err := s.saveToFileUnlocked(); err != nil {
s.visitors[name] = oldVisitor
return fmt.Errorf("failed to persist: %w", err)
}
return nil
return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name)
}
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {

View file

@ -17,6 +17,7 @@ package source
import (
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/require"
@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
}
func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) {
require := require.New(t)
storeSource := newTestStoreSource(t)
proxyCfg := mockProxy("proxy1")
visitorCfg := mockVisitor("visitor1")
require.NoError(storeSource.AddProxy(proxyCfg))
require.NoError(storeSource.AddVisitor(visitorCfg))
require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists)
require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists)
require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty")
require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty")
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
updatedProxy.RemotePort = 19090
require.NoError(storeSource.UpdateProxy(updatedProxy))
require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
updatedVisitor.ServerName = "updated-server"
require.NoError(storeSource.UpdateVisitor(updatedVisitor))
require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
require.NoError(storeSource.RemoveProxy("proxy1"))
require.Nil(storeSource.GetProxy("proxy1"))
require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound)
require.NoError(storeSource.RemoveVisitor("visitor1"))
require.Nil(storeSource.GetVisitor("visitor1"))
require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound)
require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound)
require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound)
}
func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("chmod does not make directories unwritable on Windows")
}
if os.Getuid() == 0 {
t.Skip("chmod does not block writes for uid 0")
}
require := require.New(t)
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
require.NoError(err)
proxyCfg := mockProxy("proxy1")
visitorCfg := mockVisitor("visitor1")
originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort
originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName
require.NoError(storeSource.AddProxy(proxyCfg))
require.NoError(storeSource.AddVisitor(visitorCfg))
require.NoError(os.Chmod(dir, 0o500))
t.Cleanup(func() {
_ = os.Chmod(dir, 0o700)
})
requirePersistError := func(err error) {
t.Helper()
require.Error(err)
require.ErrorContains(err, "failed to persist")
require.NotErrorIs(err, ErrAlreadyExists)
require.NotErrorIs(err, ErrNotFound)
}
requirePersistError(storeSource.AddProxy(mockProxy("proxy2")))
require.Nil(storeSource.GetProxy("proxy2"))
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
updatedProxy.RemotePort = 19090
requirePersistError(storeSource.UpdateProxy(updatedProxy))
require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
requirePersistError(storeSource.RemoveProxy("proxy1"))
require.NotNil(storeSource.GetProxy("proxy1"))
requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2")))
require.Nil(storeSource.GetVisitor("visitor2"))
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
updatedVisitor.ServerName = "updated-server"
requirePersistError(storeSource.UpdateVisitor(updatedVisitor))
require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
requirePersistError(storeSource.RemoveVisitor("visitor1"))
require.NotNil(storeSource.GetVisitor("visitor1"))
}
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
require := require.New(t)

View file

@ -0,0 +1,43 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package source
import (
"fmt"
v1 "github.com/fatedier/frp/pkg/config/v1"
)
func validateProxyName(proxy v1.ProxyConfigurer) (string, error) {
if proxy == nil {
return "", fmt.Errorf("proxy cannot be nil")
}
name := proxy.GetBaseConfig().Name
if name == "" {
return "", fmt.Errorf("proxy name cannot be empty")
}
return name, nil
}
func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) {
if visitor == nil {
return "", fmt.Errorf("visitor cannot be nil")
}
name := visitor.GetBaseConfig().Name
if name == "" {
return "", fmt.Errorf("visitor name cannot be empty")
}
return name, nil
}

View file

@ -0,0 +1,43 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validation
import (
"fmt"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/policy/security"
)
func (v *ConfigValidator) validateAuthTokenSource(token string, tokenSource *v1.ValueSource) error {
var errs error
// Preserve the previous client/server validation order for joined errors.
if token != "" && tokenSource != nil {
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
}
if tokenSource == nil {
return errs
}
if tokenSource.Type == "exec" {
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
errs = AppendError(errs, err)
}
}
if err := tokenSource.Validate(); err != nil {
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
}
return errs
}

View file

@ -0,0 +1,228 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validation
import (
"testing"
"github.com/stretchr/testify/require"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/policy/security"
)
const (
tokenSourceConflictErr = "cannot specify both auth.token and auth.tokenSource"
tokenSourceExecErr = "unsafe feature \"TokenSourceExec\" is not enabled. To enable it, ensure it is allowed in the configuration or command line flags"
invalidFileSourceErr = "invalid auth.tokenSource: file configuration is required when type is 'file'"
unsupportedSourceErr = "invalid auth.tokenSource: unsupported value source type: env (only 'file' and 'exec' are supported)"
)
func TestValidateAuthTokenSource(t *testing.T) {
for _, tc := range authTokenSourceTestCases() {
t.Run(tc.name, func(t *testing.T) {
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
err := validator.validateAuthTokenSource(tc.token, tc.tokenSource())
requireValidationErrors(t, err, tc.wantErrs)
})
}
}
func TestValidateClientAuthTokenSource(t *testing.T) {
for _, tc := range authTokenSourceTestCases() {
t.Run(tc.name, func(t *testing.T) {
auth := v1.AuthClientConfig{
Method: v1.AuthMethodToken,
Token: tc.token,
TokenSource: tc.tokenSource(),
}
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
_, err := validator.ValidateClientCommonConfig(validClientConfigWithAuth(auth))
requireValidationErrors(t, err, tc.wantErrs)
})
}
}
func TestValidateServerAuthTokenSource(t *testing.T) {
for _, tc := range authTokenSourceTestCases() {
t.Run(tc.name, func(t *testing.T) {
auth := v1.AuthServerConfig{
Method: v1.AuthMethodToken,
Token: tc.token,
TokenSource: tc.tokenSource(),
}
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
_, err := validator.ValidateServerConfig(validServerConfigWithAuth(auth))
requireValidationErrors(t, err, tc.wantErrs)
})
}
}
type authTokenSourceTestCase struct {
name string
token string
tokenSource func() *v1.ValueSource
unsafeAllowed bool
wantErrs []string
}
func authTokenSourceTestCases() []authTokenSourceTestCase {
return []authTokenSourceTestCase{
{
name: "empty token config",
tokenSource: nilTokenSource,
},
{
name: "valid file tokenSource",
tokenSource: validFileTokenSource,
},
{
name: "literal token without tokenSource",
token: "token",
tokenSource: nilTokenSource,
},
{
name: "literal token conflicts with file tokenSource",
token: "token",
tokenSource: validFileTokenSource,
wantErrs: []string{tokenSourceConflictErr},
},
{
name: "exec tokenSource requires unsafe feature",
tokenSource: validExecTokenSource,
wantErrs: []string{tokenSourceExecErr},
},
{
name: "exec tokenSource with unsafe feature allowed",
tokenSource: validExecTokenSource,
unsafeAllowed: true,
},
{
name: "literal token conflicts with exec tokenSource and unsafe feature disabled",
token: "token",
tokenSource: validExecTokenSource,
wantErrs: []string{
tokenSourceConflictErr,
tokenSourceExecErr,
},
},
{
name: "literal token conflicts with exec tokenSource and unsafe feature allowed",
token: "token",
tokenSource: validExecTokenSource,
unsafeAllowed: true,
wantErrs: []string{tokenSourceConflictErr},
},
{
name: "invalid file tokenSource is wrapped",
tokenSource: invalidFileTokenSource,
wantErrs: []string{invalidFileSourceErr},
},
{
name: "unsupported tokenSource type is wrapped",
tokenSource: unsupportedTokenSource,
wantErrs: []string{unsupportedSourceErr},
},
}
}
func newAuthTokenSourceValidator(unsafeAllowed bool) *ConfigValidator {
if !unsafeAllowed {
return NewConfigValidator(nil)
}
return NewConfigValidator(security.NewUnsafeFeatures([]string{security.TokenSourceExec}))
}
func requireValidationErrors(t *testing.T, err error, wantErrs []string) {
t.Helper()
if len(wantErrs) == 0 {
require.NoError(t, err)
return
}
require.Error(t, err)
// Client/server validators may wrap joined errors in another join layer; compare leaf errors.
gotErrs := unwrapValidationErrors(err)
require.Len(t, gotErrs, len(wantErrs))
for i, wantErr := range wantErrs {
require.EqualError(t, gotErrs[i], wantErr)
}
}
func unwrapValidationErrors(err error) []error {
type joinedError interface {
Unwrap() []error
}
joined, ok := err.(joinedError)
if !ok {
return []error{err}
}
var errs []error
for _, err := range joined.Unwrap() {
errs = append(errs, unwrapValidationErrors(err)...)
}
return errs
}
// nilTokenSource keeps the shared table shape uniform for cases without a tokenSource.
func nilTokenSource() *v1.ValueSource {
return nil
}
func validFileTokenSource() *v1.ValueSource {
return &v1.ValueSource{
Type: "file",
File: &v1.FileSource{Path: "token.txt"},
}
}
func validExecTokenSource() *v1.ValueSource {
return &v1.ValueSource{
Type: "exec",
Exec: &v1.ExecSource{Command: "print-token"},
}
}
func invalidFileTokenSource() *v1.ValueSource {
return &v1.ValueSource{
Type: "file",
}
}
func unsupportedTokenSource() *v1.ValueSource {
return &v1.ValueSource{Type: "env"}
}
func validClientConfigWithAuth(auth v1.AuthClientConfig) *v1.ClientCommonConfig {
return &v1.ClientCommonConfig{
Auth: auth,
Log: v1.LogConfig{
Level: "info",
},
Transport: v1.ClientTransportConfig{
Protocol: "tcp",
WireProtocol: "v1",
},
}
}
func validServerConfigWithAuth(auth v1.AuthServerConfig) *v1.ServerConfig {
return &v1.ServerConfig{
Auth: auth,
Log: v1.LogConfig{
Level: "info",
},
}
}

View file

@ -68,22 +68,7 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
}
// Validate token/tokenSource mutual exclusivity
if c.Token != "" && c.TokenSource != nil {
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
}
// Validate tokenSource if specified
if c.TokenSource != nil {
if c.TokenSource.Type == "exec" {
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
errs = AppendError(errs, err)
}
}
if err := c.TokenSource.Validate(); err != nil {
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
}
}
errs = AppendError(errs, v.validateAuthTokenSource(c.Token, c.TokenSource))
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
errs = AppendError(errs, err)

View file

@ -21,7 +21,6 @@ import (
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/policy/security"
)
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
@ -36,22 +35,7 @@ func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, err
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
}
// Validate token/tokenSource mutual exclusivity
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
}
// Validate tokenSource if specified
if c.Auth.TokenSource != nil {
if c.Auth.TokenSource.Type == "exec" {
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
errs = AppendError(errs, err)
}
}
if err := c.Auth.TokenSource.Validate(); err != nil {
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
}
}
errs = AppendError(errs, v.validateAuthTokenSource(c.Auth.Token, c.Auth.TokenSource))
if err := validateLogConfig(&c.Log); err != nil {
errs = AppendError(errs, err)

View file

@ -18,6 +18,8 @@ import (
"sync"
"time"
"k8s.io/utils/clock"
"github.com/fatedier/frp/pkg/util/log"
"github.com/fatedier/frp/pkg/util/metric"
server "github.com/fatedier/frp/server/metrics"
@ -38,11 +40,20 @@ func init() {
type serverMetrics struct {
info *ServerStatistics
clock clock.WithTicker
mu sync.Mutex
}
func newServerMetrics() *serverMetrics {
return newServerMetricsWithClock(clock.RealClock{})
}
func newServerMetricsWithClock(clk clock.WithTicker) *serverMetrics {
if clk == nil {
clk = clock.RealClock{}
}
return &serverMetrics{
clock: clk,
info: &ServerStatistics{
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
@ -57,14 +68,23 @@ func newServerMetrics() *serverMetrics {
}
func (m *serverMetrics) run() {
go func() {
for {
time.Sleep(12 * time.Hour)
start := time.Now()
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start))
go m.runUntil(nil)
}
func (m *serverMetrics) runUntil(stopCh <-chan struct{}) {
ticker := m.clock.NewTicker(12 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C():
start := m.clock.Now()
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, m.clock.Since(start))
case <-stopCh:
return
}
}
}()
}
func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) {
@ -77,7 +97,7 @@ func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration
for name, data := range m.info.ProxyStatistics {
if !data.LastCloseTime.IsZero() &&
data.LastStartTime.Before(data.LastCloseTime) &&
time.Since(data.LastCloseTime) > continuousOfflineDuration {
m.clock.Since(data.LastCloseTime) > continuousOfflineDuration {
delete(m.info.ProxyStatistics, name)
count++
log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
@ -121,7 +141,7 @@ func (m *serverMetrics) NewProxy(name string, proxyType string, user string, cli
}
proxyStats.User = user
proxyStats.ClientID = clientID
proxyStats.LastStartTime = time.Now()
proxyStats.LastStartTime = m.clock.Now()
}
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
@ -131,7 +151,7 @@ func (m *serverMetrics) CloseProxy(name string, proxyType string) {
counter.Dec(1)
}
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
proxyStats.LastCloseTime = time.Now()
proxyStats.LastCloseTime = m.clock.Now()
}
}

View file

@ -0,0 +1,83 @@
package mem
import (
"testing"
"time"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
)
func TestServerMetricsUsesClockForProxyTimestamps(t *testing.T) {
require := require.New(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
metrics := newServerMetricsWithClock(clk)
metrics.NewProxy("proxy", "tcp", "user", "client-id")
require.Equal(start, metrics.info.ProxyStatistics["proxy"].LastStartTime)
closedAt := start.Add(time.Minute)
clk.SetTime(closedAt)
metrics.CloseProxy("proxy", "tcp")
require.Equal(closedAt, metrics.info.ProxyStatistics["proxy"].LastCloseTime)
}
func TestServerMetricsClearUselessInfoUsesClock(t *testing.T) {
require := require.New(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start.Add(25 * time.Hour))
metrics := newServerMetricsWithClock(clk)
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
Name: "proxy",
LastStartTime: start.Add(-time.Hour),
LastCloseTime: start,
}
count, total := metrics.clearUselessInfo(24 * time.Hour)
require.Equal(1, count)
require.Equal(1, total)
require.Empty(metrics.info.ProxyStatistics)
}
func TestServerMetricsRunUsesClockTicker(t *testing.T) {
require := require.New(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
metrics := newServerMetricsWithClock(clk)
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
Name: "proxy",
LastStartTime: start.Add(-time.Hour),
LastCloseTime: start,
}
stopCh := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
metrics.runUntil(stopCh)
}()
t.Cleanup(func() {
close(stopCh)
<-done
})
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
clk.Step(8 * 24 * time.Hour)
require.Eventually(func() bool {
return !metrics.hasProxyStatistics("proxy")
}, time.Second, time.Millisecond)
}
func (m *serverMetrics) hasProxyStatistics(name string) bool {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.info.ProxyStatistics[name]
return ok
}

View file

@ -112,7 +112,6 @@ type Dispatcher struct {
sendCh chan Message
doneCh chan struct{}
msgHandlers map[reflect.Type]func(Message)
defaultHandler func(Message)
}
func NewDispatcher(rw ReadWriter) *Dispatcher {
@ -151,8 +150,6 @@ func (d *Dispatcher) readLoop() {
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
handler(m)
} else if d.defaultHandler != nil {
d.defaultHandler(m)
}
}
}
@ -170,10 +167,6 @@ func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
d.msgHandlers[reflect.TypeOf(msg)] = handler
}
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
d.defaultHandler = handler
}
func (d *Dispatcher) Done() chan struct{} {
return d.doneCh
}

View file

@ -21,6 +21,7 @@ import (
"time"
"github.com/samber/lo"
"k8s.io/utils/clock"
)
var (
@ -144,19 +145,19 @@ func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, Recomman
return behaviors[index].A, behaviors[index].B
}
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
func getBehaviorScoresByMode(mode int, defaultScore int) []*behaviorScore {
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
}
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*behaviorScore {
behaviors := getBehaviorByMode(mode)
scores := make([]*BehaviorScore, 0, len(behaviors))
scores := make([]*behaviorScore, 0, len(behaviors))
for i := range behaviors {
score := receiverScore
if behaviors[i].A.Role == DetectRoleSender {
score = senderScore
}
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
scores = append(scores, &behaviorScore{Mode: mode, Index: i, Score: score})
}
return scores
}
@ -170,14 +171,18 @@ type RecommandBehavior struct {
ListenRandomPorts int
}
type MakeHoleRecords struct {
type makeHoleRecords struct {
mu sync.Mutex
scores []*BehaviorScore
LastUpdateTime time.Time
scores []*behaviorScore
clock clock.PassiveClock
lastUpdateTime time.Time
}
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
scores := []*BehaviorScore{}
func newMakeHoleRecordsWithClock(c, v *NatFeature, clk clock.PassiveClock) *makeHoleRecords {
if clk == nil {
clk = clock.RealClock{}
}
scores := []*behaviorScore{}
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
appendMode0 := func() {
switch {
@ -212,13 +217,17 @@ func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
}
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
return &makeHoleRecords{
scores: scores,
clock: clk,
lastUpdateTime: clk.Now(),
}
}
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
func (mhr *makeHoleRecords) reportSuccess(mode int, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
mhr.LastUpdateTime = time.Now()
mhr.lastUpdateTime = mhr.clock.Now()
for i := range mhr.scores {
score := mhr.scores[i]
if score.Mode != mode || score.Index != index {
@ -231,22 +240,22 @@ func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
}
}
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
func (mhr *makeHoleRecords) recommand() (mode, index int) {
mhr.mu.Lock()
defer mhr.mu.Unlock()
if len(mhr.scores) == 0 {
return 0, 0
}
maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int {
maxScore := slices.MaxFunc(mhr.scores, func(a, b *behaviorScore) int {
return cmp.Compare(a.Score, b.Score)
})
maxScore.Score--
mhr.LastUpdateTime = time.Now()
mhr.lastUpdateTime = mhr.clock.Now()
return maxScore.Mode, maxScore.Index
}
type BehaviorScore struct {
type behaviorScore struct {
Mode int
Index int
// between -10 and 10
@ -255,16 +264,25 @@ type BehaviorScore struct {
type Analyzer struct {
// key is client ip + visitor ip
records map[string]*MakeHoleRecords
records map[string]*makeHoleRecords
dataReserveDuration time.Duration
clock clock.PassiveClock
mu sync.Mutex
}
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
return newAnalyzerWithClock(dataReserveDuration, clock.RealClock{})
}
func newAnalyzerWithClock(dataReserveDuration time.Duration, clk clock.PassiveClock) *Analyzer {
if clk == nil {
clk = clock.RealClock{}
}
return &Analyzer{
records: make(map[string]*MakeHoleRecords),
records: make(map[string]*makeHoleRecords),
dataReserveDuration: dataReserveDuration,
clock: clk,
}
}
@ -272,12 +290,12 @@ func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, in
a.mu.Lock()
records, ok := a.records[key]
if !ok {
records = NewMakeHoleRecords(c, v)
records = newMakeHoleRecordsWithClock(c, v, a.clock)
a.records[key] = records
}
a.mu.Unlock()
mode, index = records.Recommand()
mode, index = records.recommand()
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
switch mode {
@ -307,11 +325,11 @@ func (a *Analyzer) ReportSuccess(key string, mode, index int) {
if !ok {
return
}
records.ReportSuccess(mode, index)
records.reportSuccess(mode, index)
}
func (a *Analyzer) Clean() (int, int) {
now := time.Now()
now := a.clock.Now()
total := 0
count := 0
@ -321,7 +339,7 @@ func (a *Analyzer) Clean() (int, int) {
total = len(a.records)
// clean up records that have not been used for a period of time.
for key, records := range a.records {
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
if now.Sub(records.lastUpdateTime) > a.dataReserveDuration {
delete(a.records, key)
count++
}

View file

@ -0,0 +1,33 @@
package nathole
import (
"testing"
"time"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
)
func TestAnalyzerUsesClockForRecordTimestamps(t *testing.T) {
require := require.New(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
analyzer := newAnalyzerWithClock(time.Hour, clk)
clientFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
visitorFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
mode, index, _, _ := analyzer.GetRecommandBehaviors("key", clientFeature, visitorFeature)
require.Equal(start, analyzer.records["key"].lastUpdateTime)
updatedAt := start.Add(time.Minute)
clk.SetTime(updatedAt)
analyzer.ReportSuccess("key", mode, index)
require.Equal(updatedAt, analyzer.records["key"].lastUpdateTime)
clk.SetTime(start.Add(2 * time.Hour))
count, total := analyzer.Clean()
require.Equal(1, count)
require.Equal(1, total)
require.Empty(analyzer.records)
}

View file

@ -326,40 +326,16 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
}
protocol := vm.Protocol
vResp := &msg.NatHoleResp{
TransactionID: vm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: slices.Compact(cm.MappedAddrs),
AssistedAddrs: slices.Compact(cm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: vBehavior.Role,
TTL: vBehavior.TTL,
SendDelayMs: vBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
SendRandomPorts: vBehavior.PortsRandomNumber,
ListenRandomPorts: vBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
},
}
cResp := &msg.NatHoleResp{
TransactionID: cm.TransactionID,
Sid: session.sid,
Protocol: protocol,
CandidateAddrs: slices.Compact(vm.MappedAddrs),
AssistedAddrs: slices.Compact(vm.AssistedAddrs),
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: cBehavior.Role,
TTL: cBehavior.TTL,
SendDelayMs: cBehavior.SendDelayMs,
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
SendRandomPorts: cBehavior.PortsRandomNumber,
ListenRandomPorts: cBehavior.ListenRandomPorts,
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
},
}
vResp := newNatHoleResponse(
vm.TransactionID, session.sid, protocol, mode,
cm.MappedAddrs, cm.AssistedAddrs, vBehavior,
timeoutMs-vBehavior.SendDelayMs, cNatFeature.PortsDifference,
)
cResp := newNatHoleResponse(
cm.TransactionID, session.sid, protocol, mode,
vm.MappedAddrs, vm.AssistedAddrs, cBehavior,
timeoutMs-cBehavior.SendDelayMs, vNatFeature.PortsDifference,
)
log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
@ -368,6 +344,38 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
return vResp, cResp, nil
}
func newNatHoleResponse(
transactionID string,
sid string,
protocol string,
mode int,
candidateAddrs []string,
assistedAddrs []string,
behavior RecommandBehavior,
readTimeoutMs int,
portsDifference int,
) *msg.NatHoleResp {
compactCandidateAddrs := slices.Compact(candidateAddrs)
compactAssistedAddrs := slices.Compact(assistedAddrs)
return &msg.NatHoleResp{
TransactionID: transactionID,
Sid: sid,
Protocol: protocol,
CandidateAddrs: compactCandidateAddrs,
AssistedAddrs: compactAssistedAddrs,
DetectBehavior: msg.NatHoleDetectBehavior{
Mode: mode,
Role: behavior.Role,
TTL: behavior.TTL,
SendDelayMs: behavior.SendDelayMs,
ReadTimeoutMs: readTimeoutMs,
SendRandomPorts: behavior.PortsRandomNumber,
ListenRandomPorts: behavior.ListenRandomPorts,
CandidatePorts: getRangePorts(candidateAddrs, portsDifference, behavior.PortsRangeNumber),
},
}
}
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
if maxNumber <= 0 {
return nil

View file

@ -17,17 +17,9 @@
package client
import (
"context"
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
@ -37,57 +29,28 @@ func init() {
type HTTP2HTTPPlugin struct {
opts *v1.HTTP2HTTPPluginOptions
l *Listener
s *http.Server
*httpBridgePlugin
}
func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
opts := options.(*v1.HTTP2HTTPPluginOptions)
listener := NewProxyListener()
p := &HTTP2HTTPPlugin{
opts: opts,
l: listener,
}
rp := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
rp := newHTTPBridgeReverseProxy(
func(r *httputil.ProxyRequest) {
req := r.Out
req.URL.Scheme = "http"
req.URL.Host = p.opts.LocalAddr
if p.opts.HostHeaderRewrite != "" {
req.Host = p.opts.HostHeaderRewrite
}
for k, v := range p.opts.RequestHeaders.Set {
req.Header.Set(k, v)
}
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
},
BufferPool: pool.NewBuffer(32 * 1024),
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
}
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 60 * time.Second,
}
go func() {
_ = p.s.Serve(listener)
}()
nil,
)
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
return p, nil
}
func (p *HTTP2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
_ = p.l.PutConn(wrapConn)
}
func (p *HTTP2HTTPPlugin) Name() string {
return v1.PluginHTTP2HTTP
}
func (p *HTTP2HTTPPlugin) Close() error {
return p.s.Close()
}

View file

@ -17,18 +17,11 @@
package client
import (
"context"
"crypto/tls"
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
@ -38,65 +31,35 @@ func init() {
type HTTP2HTTPSPlugin struct {
opts *v1.HTTP2HTTPSPluginOptions
l *Listener
s *http.Server
*httpBridgePlugin
}
func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
opts := options.(*v1.HTTP2HTTPSPluginOptions)
listener := NewProxyListener()
p := &HTTP2HTTPSPlugin{
opts: opts,
l: listener,
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
rp := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
rp := newHTTPBridgeReverseProxy(
func(r *httputil.ProxyRequest) {
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
req := r.Out
req.URL.Scheme = "https"
req.URL.Host = p.opts.LocalAddr
if p.opts.HostHeaderRewrite != "" {
req.Host = p.opts.HostHeaderRewrite
}
for k, v := range p.opts.RequestHeaders.Set {
req.Header.Set(k, v)
}
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
},
Transport: tr,
BufferPool: pool.NewBuffer(32 * 1024),
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
}
p.s = &http.Server{
Handler: rp,
ReadHeaderTimeout: 60 * time.Second,
}
go func() {
_ = p.s.Serve(listener)
}()
tr,
)
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
return p, nil
}
func (p *HTTP2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
_ = p.l.PutConn(wrapConn)
}
func (p *HTTP2HTTPSPlugin) Name() string {
return v1.PluginHTTP2HTTPS
}
func (p *HTTP2HTTPSPlugin) Close() error {
return p.s.Close()
}

View file

@ -0,0 +1,126 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !frps
package client
import (
"context"
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/plugin/client/internal/httpsserver"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
const httpBridgeReadHeaderTimeout = 60 * time.Second
func rewriteHTTPPluginRequest(
req *http.Request,
scheme string,
localAddr string,
hostHeaderRewrite string,
requestHeaders v1.HeaderOperations,
) {
req.URL.Scheme = scheme
req.URL.Host = localAddr
if hostHeaderRewrite != "" {
req.Host = hostHeaderRewrite
}
for k, v := range requestHeaders.Set {
req.Header.Set(k, v)
}
}
type httpBridgePlugin struct {
l *Listener
s *http.Server
useSourceRemoteAddr bool
}
func newHTTPBridgePluginServer(handler http.Handler, useSourceRemoteAddr bool) *httpBridgePlugin {
listener := NewProxyListener()
p := &httpBridgePlugin{
l: listener,
useSourceRemoteAddr: useSourceRemoteAddr,
}
p.s = &http.Server{
Handler: handler,
ReadHeaderTimeout: httpBridgeReadHeaderTimeout,
}
go func() {
_ = p.s.Serve(listener)
}()
return p
}
func newHTTPSBridgePluginServer(
handler http.Handler,
crtPath string,
keyPath string,
enableHTTP2 *bool,
useSourceRemoteAddr bool,
) (*httpBridgePlugin, error) {
listener := NewProxyListener()
server, err := httpsserver.New(handler, crtPath, keyPath, enableHTTP2)
if err != nil {
return nil, err
}
p := &httpBridgePlugin{
l: listener,
s: server,
useSourceRemoteAddr: useSourceRemoteAddr,
}
go func() {
_ = p.s.ServeTLS(listener, "", "")
}()
return p, nil
}
func newHTTPBridgeReverseProxy(
rewrite func(*httputil.ProxyRequest),
transport http.RoundTripper,
) *httputil.ReverseProxy {
rp := &httputil.ReverseProxy{
Rewrite: rewrite,
BufferPool: pool.NewBuffer(32 * 1024),
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
}
if transport != nil {
rp.Transport = transport
}
return rp
}
func (p *httpBridgePlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
if p.useSourceRemoteAddr && connInfo.SrcAddr != nil {
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
}
_ = p.l.PutConn(wrapConn)
}
func (p *httpBridgePlugin) Close() error {
err := p.s.Close()
_ = p.l.Close()
return err
}

View file

@ -17,22 +17,9 @@
package client
import (
"context"
"crypto/tls"
"fmt"
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
@ -42,80 +29,35 @@ func init() {
type HTTPS2HTTPPlugin struct {
opts *v1.HTTPS2HTTPPluginOptions
l *Listener
s *http.Server
*httpBridgePlugin
}
func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
opts := options.(*v1.HTTPS2HTTPPluginOptions)
listener := NewProxyListener()
p := &HTTPS2HTTPPlugin{
opts: opts,
l: listener,
}
rp := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
rp := newHTTPBridgeReverseProxy(
func(r *httputil.ProxyRequest) {
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
r.SetXForwarded()
req := r.Out
req.URL.Scheme = "http"
req.URL.Host = p.opts.LocalAddr
if p.opts.HostHeaderRewrite != "" {
req.Host = p.opts.HostHeaderRewrite
}
for k, v := range p.opts.RequestHeaders.Set {
req.Header.Set(k, v)
}
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
},
BufferPool: pool.NewBuffer(32 * 1024),
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
host, _ := httppkg.CanonicalHost(r.Host)
if tlsServerName != "" && tlsServerName != host {
w.WriteHeader(http.StatusMisdirectedRequest)
return
}
}
rp.ServeHTTP(w, r)
})
nil,
)
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
if err != nil {
return nil, fmt.Errorf("gen TLS config error: %v", err)
return nil, err
}
p.httpBridgePlugin = server
p.s = &http.Server{
Handler: handler,
ReadHeaderTimeout: 60 * time.Second,
TLSConfig: tlsConfig,
}
if !lo.FromPtr(opts.EnableHTTP2) {
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
}
go func() {
_ = p.s.ServeTLS(listener, "", "")
}()
return p, nil
}
func (p *HTTPS2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
if connInfo.SrcAddr != nil {
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
}
_ = p.l.PutConn(wrapConn)
}
func (p *HTTPS2HTTPPlugin) Name() string {
return v1.PluginHTTPS2HTTP
}
func (p *HTTPS2HTTPPlugin) Close() error {
return p.s.Close()
}

View file

@ -17,22 +17,11 @@
package client
import (
"context"
"crypto/tls"
"fmt"
stdlog "log"
"net/http"
"net/http/httputil"
"time"
"github.com/fatedier/golib/pool"
"github.com/samber/lo"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport"
httppkg "github.com/fatedier/frp/pkg/util/http"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
func init() {
@ -42,86 +31,39 @@ func init() {
type HTTPS2HTTPSPlugin struct {
opts *v1.HTTPS2HTTPSPluginOptions
l *Listener
s *http.Server
*httpBridgePlugin
}
func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
opts := options.(*v1.HTTPS2HTTPSPluginOptions)
listener := NewProxyListener()
p := &HTTPS2HTTPSPlugin{
opts: opts,
l: listener,
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
rp := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
rp := newHTTPBridgeReverseProxy(
func(r *httputil.ProxyRequest) {
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
r.SetXForwarded()
req := r.Out
req.URL.Scheme = "https"
req.URL.Host = p.opts.LocalAddr
if p.opts.HostHeaderRewrite != "" {
req.Host = p.opts.HostHeaderRewrite
}
for k, v := range p.opts.RequestHeaders.Set {
req.Header.Set(k, v)
}
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
},
Transport: tr,
BufferPool: pool.NewBuffer(32 * 1024),
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
host, _ := httppkg.CanonicalHost(r.Host)
if tlsServerName != "" && tlsServerName != host {
w.WriteHeader(http.StatusMisdirectedRequest)
return
}
}
rp.ServeHTTP(w, r)
})
tr,
)
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
if err != nil {
return nil, fmt.Errorf("gen TLS config error: %v", err)
return nil, err
}
p.httpBridgePlugin = server
p.s = &http.Server{
Handler: handler,
ReadHeaderTimeout: 60 * time.Second,
TLSConfig: tlsConfig,
}
if !lo.FromPtr(opts.EnableHTTP2) {
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
}
go func() {
_ = p.s.ServeTLS(listener, "", "")
}()
return p, nil
}
func (p *HTTPS2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
if connInfo.SrcAddr != nil {
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
}
_ = p.l.PutConn(wrapConn)
}
func (p *HTTPS2HTTPSPlugin) Name() string {
return v1.PluginHTTPS2HTTPS
}
func (p *HTTPS2HTTPSPlugin) Close() error {
return p.s.Close()
}

View file

@ -0,0 +1,60 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !frps
package httpsserver
import (
"crypto/tls"
"fmt"
"net/http"
"time"
"github.com/samber/lo"
"github.com/fatedier/frp/pkg/transport"
httppkg "github.com/fatedier/frp/pkg/util/http"
)
func New(handler http.Handler, crtPath, keyPath string, enableHTTP2 *bool) (*http.Server, error) {
tlsConfig, err := transport.NewServerTLSConfig(crtPath, keyPath, "")
if err != nil {
return nil, fmt.Errorf("gen TLS config error: %v", err)
}
server := &http.Server{
Handler: withMisdirectedRequestCheck(handler),
ReadHeaderTimeout: 60 * time.Second,
TLSConfig: tlsConfig,
}
if !lo.FromPtr(enableHTTP2) {
server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
}
return server, nil
}
func withMisdirectedRequestCheck(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
host, _ := httppkg.CanonicalHost(r.Host)
if tlsServerName != "" && tlsServerName != host {
w.WriteHeader(http.StatusMisdirectedRequest)
return
}
}
handler.ServeHTTP(w, r)
})
}

View file

@ -44,6 +44,67 @@ func NewManager() *Manager {
}
}
func newPluginRequestContext() (context.Context, *xlog.Logger) {
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
return NewReqidContext(ctx, reqid), xl
}
type pluginErrorLogMode bool
const (
// Warn is the zero value because it is the default for mutable plugin operations.
pluginErrorLogWarn pluginErrorLogMode = false
pluginErrorLogInfo pluginErrorLogMode = true
)
func logPluginError(xl *xlog.Logger, p Plugin, op string, err error, mode pluginErrorLogMode) {
if mode == pluginErrorLogInfo {
xl.Infof("send %s request to plugin [%s] error: %v", op, p.Name(), err)
return
}
xl.Warnf("send %s request to plugin [%s] error: %v", op, p.Name(), err)
}
func handleMutableContent[T any](
plugins []Plugin,
op string,
content *T,
logMode pluginErrorLogMode,
) (*T, error) {
if len(plugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
ctx, xl := newPluginRequestContext()
for _, p := range plugins {
res, retContent, err = p.Handle(ctx, op, *content)
if err != nil {
logPluginError(xl, p, op, err, logMode)
return nil, errors.New("send " + op + " request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
// Preserve the existing Plugin contract: changed content must be *T.
// Buggy Plugin implementations still panic here, by design.
content = retContent.(*T)
}
}
return content, nil
}
func (m *Manager) Register(p Plugin) {
if p.IsSupport(OpLogin) {
m.loginPlugins = append(m.loginPlugins, p)
@ -66,71 +127,11 @@ func (m *Manager) Register(p Plugin) {
}
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
if len(m.loginPlugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
for _, p := range m.loginPlugins {
res, retContent, err = p.Handle(ctx, OpLogin, *content)
if err != nil {
xl.Warnf("send Login request to plugin [%s] error: %v", p.Name(), err)
return nil, errors.New("send Login request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
content = retContent.(*LoginContent)
}
}
return content, nil
return handleMutableContent(m.loginPlugins, OpLogin, content, pluginErrorLogWarn)
}
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
if len(m.newProxyPlugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
for _, p := range m.newProxyPlugins {
res, retContent, err = p.Handle(ctx, OpNewProxy, *content)
if err != nil {
xl.Warnf("send NewProxy request to plugin [%s] error: %v", p.Name(), err)
return nil, errors.New("send NewProxy request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
content = retContent.(*NewProxyContent)
}
}
return content, nil
return handleMutableContent(m.newProxyPlugins, OpNewProxy, content, pluginErrorLogWarn)
}
func (m *Manager) CloseProxy(content *CloseProxyContent) error {
@ -139,10 +140,7 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
}
errs := make([]string, 0)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
ctx, xl := newPluginRequestContext()
for _, p := range m.closeProxyPlugins {
_, _, err := p.Handle(ctx, OpCloseProxy, *content)
@ -159,103 +157,14 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
}
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
if len(m.pingPlugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
for _, p := range m.pingPlugins {
res, retContent, err = p.Handle(ctx, OpPing, *content)
if err != nil {
xl.Warnf("send Ping request to plugin [%s] error: %v", p.Name(), err)
return nil, errors.New("send Ping request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
content = retContent.(*PingContent)
}
}
return content, nil
return handleMutableContent(m.pingPlugins, OpPing, content, pluginErrorLogWarn)
}
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
if len(m.newWorkConnPlugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
for _, p := range m.newWorkConnPlugins {
res, retContent, err = p.Handle(ctx, OpNewWorkConn, *content)
if err != nil {
xl.Warnf("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err)
return nil, errors.New("send NewWorkConn request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
content = retContent.(*NewWorkConnContent)
}
}
return content, nil
return handleMutableContent(m.newWorkConnPlugins, OpNewWorkConn, content, pluginErrorLogWarn)
}
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
if len(m.newUserConnPlugins) == 0 {
return content, nil
}
var (
res = &Response{
Reject: false,
Unchange: true,
}
retContent any
err error
)
reqid, _ := util.RandID()
xl := xlog.New().AppendPrefix("reqid: " + reqid)
ctx := xlog.NewContext(context.Background(), xl)
ctx = NewReqidContext(ctx, reqid)
for _, p := range m.newUserConnPlugins {
res, retContent, err = p.Handle(ctx, OpNewUserConn, *content)
if err != nil {
xl.Infof("send NewUserConn request to plugin [%s] error: %v", p.Name(), err)
return nil, errors.New("send NewUserConn request to plugin error")
}
if res.Reject {
return nil, fmt.Errorf("%s", res.RejectReason)
}
if !res.Unchange {
content = retContent.(*NewUserConnContent)
}
}
return content, nil
// Preserve the pre-refactor log level for NewUserConn plugin errors.
return handleMutableContent(m.newUserConnPlugins, OpNewUserConn, content, pluginErrorLogInfo)
}

View file

@ -0,0 +1,336 @@
// Copyright 2019 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"context"
"errors"
"strings"
"sync"
"testing"
"time"
goliblog "github.com/fatedier/golib/log"
"github.com/fatedier/frp/pkg/msg"
frplog "github.com/fatedier/frp/pkg/util/log"
)
type testPlugin struct {
name string
ops map[string]bool
handler func(context.Context, string, any) (*Response, any, error)
}
// Log-capturing subtests serialize global logger swaps; do not use t.Parallel.
var logCaptureMu sync.Mutex
type logCapture struct {
bytes.Buffer
levels []goliblog.Level
}
func (p testPlugin) Name() string {
return p.name
}
func (p testPlugin) IsSupport(op string) bool {
return p.ops[op]
}
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
return p.handler(ctx, op, content)
}
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
w.levels = append(w.levels, level)
return w.Write(p)
}
func captureLogOutput(t *testing.T) *logCapture {
t.Helper()
logCaptureMu.Lock()
logOutput := &logCapture{}
oldLogger := frplog.Logger
frplog.Logger = goliblog.New(
goliblog.WithOutput(logOutput),
goliblog.WithLevel(goliblog.TraceLevel),
goliblog.WithCaller(false),
)
t.Cleanup(func() {
frplog.Logger = oldLogger
logCaptureMu.Unlock()
})
return logOutput
}
var mutablePluginOps = []struct {
name string
op string
}{
{name: "login", op: OpLogin},
{name: "new proxy", op: OpNewProxy},
{name: "ping", op: OpPing},
{name: "new work conn", op: OpNewWorkConn},
{name: "new user conn", op: OpNewUserConn},
}
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
switch op {
case OpLogin:
got, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
if got == nil {
return "", err
}
return got.User, err
case OpNewProxy:
got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}})
if got == nil {
return "", err
}
return got.User.User, err
case OpPing:
got, err := m.Ping(&PingContent{User: UserInfo{User: user}})
if got == nil {
return "", err
}
return got.User.User, err
case OpNewWorkConn:
got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}})
if got == nil {
return "", err
}
return got.User.User, err
case OpNewUserConn:
got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}})
if got == nil {
return "", err
}
return got.User.User, err
default:
panic("unsupported mutable op: " + op)
}
}
func mutableUser(t *testing.T, op string, content any) string {
t.Helper()
switch op {
case OpLogin:
return content.(LoginContent).User
case OpNewProxy:
return content.(NewProxyContent).User.User
case OpPing:
return content.(PingContent).User.User
case OpNewWorkConn:
return content.(NewWorkConnContent).User.User
case OpNewUserConn:
return content.(NewUserConnContent).User.User
default:
t.Fatalf("unsupported mutable op: %s", op)
return ""
}
}
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
t.Helper()
switch op {
case OpLogin:
got := content.(LoginContent)
got.User = user
return &got
case OpNewProxy:
got := content.(NewProxyContent)
got.User.User = user
return &got
case OpPing:
got := content.(PingContent)
got.User.User = user
return &got
case OpNewWorkConn:
got := content.(NewWorkConnContent)
got.User.User = user
return &got
case OpNewUserConn:
got := content.(NewUserConnContent)
got.User.User = user
return &got
default:
t.Fatalf("unsupported mutable op: %s", op)
return nil
}
}
func TestManagerMutableContentAcrossOps(t *testing.T) {
for _, tt := range mutablePluginOps {
t.Run(tt.name, func(t *testing.T) {
m := NewManager()
m.Register(testPlugin{
name: "mutate",
ops: map[string]bool{tt.op: true},
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
if op != tt.op {
t.Fatalf("unexpected op: %s", op)
}
if GetReqidFromContext(ctx) == "" {
t.Fatal("expected request id in context")
}
if got := mutableUser(t, tt.op, content); got != "initial" {
t.Fatalf("expected initial user, got %q", got)
}
return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
},
})
m.Register(testPlugin{
name: "observe",
ops: map[string]bool{tt.op: true},
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
if op != tt.op {
t.Fatalf("unexpected op: %s", op)
}
if GetReqidFromContext(ctx) == "" {
t.Fatal("expected request id in context")
}
if got := mutableUser(t, tt.op, content); got != "mutated" {
t.Fatalf("expected mutated user, got %q", got)
}
return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
},
})
got, err := callMutableWithUser(m, tt.op, "initial")
if err != nil {
t.Fatalf("mutable op failed: %v", err)
}
if got != "mutated" {
t.Fatalf("expected mutated user, got %q", got)
}
})
}
}
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
m := NewManager()
var called bool
m.Register(testPlugin{
name: "reject",
ops: map[string]bool{OpPing: true},
handler: func(context.Context, string, any) (*Response, any, error) {
return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
},
})
m.Register(testPlugin{
name: "unused",
ops: map[string]bool{OpPing: true},
handler: func(context.Context, string, any) (*Response, any, error) {
called = true
return &Response{Unchange: true}, nil, nil
},
})
got, err := m.Ping(&PingContent{})
if err == nil {
t.Fatal("expected reject error")
}
if got != nil {
t.Fatalf("expected no returned content, got %#v", got)
}
if err.Error() != "blocked" {
t.Fatalf("unexpected error: %v", err)
}
if called {
t.Fatal("expected plugin chain to stop after reject")
}
}
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
tests := []struct {
name string
op string
level goliblog.Level
}{
{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOutput := captureLogOutput(t)
m := NewManager()
m.Register(testPlugin{
name: "error",
ops: map[string]bool{tt.op: true},
handler: func(context.Context, string, any) (*Response, any, error) {
return nil, nil, errors.New("boom")
},
})
_, err := callMutableWithUser(m, tt.op, "initial")
if err == nil {
t.Fatal("expected plugin error")
}
if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
t.Fatalf("unexpected error: %v", err)
}
if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
}
})
}
}
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
logOutput := captureLogOutput(t)
m := NewManager()
for _, name := range []string{"first", "second"} {
m.Register(testPlugin{
name: name,
ops: map[string]bool{OpCloseProxy: true},
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
if GetReqidFromContext(ctx) == "" {
t.Fatal("expected request id in context")
}
if op != OpCloseProxy {
t.Fatalf("unexpected op: %s", op)
}
return nil, nil, errors.New(name + " error")
},
})
}
err := m.CloseProxy(&CloseProxyContent{})
if err == nil {
t.Fatal("expected close proxy error")
}
if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") {
t.Fatalf("unexpected close proxy error prefix: %v", err)
}
if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") {
t.Fatalf("missing aggregated errors: %v", err)
}
if len(logOutput.levels) != 2 {
t.Fatalf("expected two warning logs, got %v", logOutput.levels)
}
for _, level := range logOutput.levels {
if level != goliblog.WarnLevel {
t.Fatalf("expected warning log level, got %v", logOutput.levels)
}
}
}

View file

@ -17,6 +17,8 @@ package metric
import (
"sync"
"time"
"k8s.io/utils/clock"
)
type DateCounter interface {
@ -38,27 +40,33 @@ func NewDateCounter(reserveDays int64) DateCounter {
type StandardDateCounter struct {
reserveDays int64
counts []int64
clock clock.PassiveClock
lastUpdateDate time.Time
mu sync.Mutex
}
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
now := time.Now()
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
s := &StandardDateCounter{
return newStandardDateCounterWithClock(reserveDays, clock.RealClock{})
}
func newStandardDateCounterWithClock(reserveDays int64, clk clock.PassiveClock) *StandardDateCounter {
if clk == nil {
clk = clock.RealClock{}
}
return &StandardDateCounter{
reserveDays: reserveDays,
counts: make([]int64, reserveDays),
lastUpdateDate: now,
clock: clk,
lastUpdateDate: startOfDay(clk.Now()),
}
return s
}
func (c *StandardDateCounter) TodayCount() int64 {
c.mu.Lock()
defer c.mu.Unlock()
c.rotate(time.Now())
c.rotate(c.clock.Now())
return c.counts[0]
}
@ -70,65 +78,61 @@ func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 {
c.mu.Lock()
defer c.mu.Unlock()
c.rotate(time.Now())
for i := 0; i < int(lastdays); i++ {
counts[i] = c.counts[i]
}
c.rotate(c.clock.Now())
copy(counts, c.counts)
return counts
}
func (c *StandardDateCounter) Inc(count int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.rotate(time.Now())
c.rotate(c.clock.Now())
c.counts[0] += count
}
func (c *StandardDateCounter) Dec(count int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.rotate(time.Now())
c.rotate(c.clock.Now())
c.counts[0] -= count
}
func (c *StandardDateCounter) Snapshot() DateCounter {
c.mu.Lock()
defer c.mu.Unlock()
tmp := newStandardDateCounter(c.reserveDays)
for i := 0; i < int(c.reserveDays); i++ {
tmp.counts[i] = c.counts[i]
}
tmp := newStandardDateCounterWithClock(c.reserveDays, c.clock)
copy(tmp.counts, c.counts)
return tmp
}
func (c *StandardDateCounter) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
for i := 0; i < int(c.reserveDays); i++ {
c.counts[i] = 0
}
clear(c.counts)
}
// rotate
// Must hold the lock before calling this function.
func (c *StandardDateCounter) rotate(now time.Time) {
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
now = startOfDay(now)
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
defer func() {
c.lastUpdateDate = now
}()
reserveDays := int(c.reserveDays)
if days <= 0 {
return
} else if days >= int(c.reserveDays) {
} else if days >= reserveDays {
c.counts = make([]int64, c.reserveDays)
c.lastUpdateDate = now
return
}
newCounts := make([]int64, c.reserveDays)
for i := days; i < int(c.reserveDays); i++ {
newCounts[i] = c.counts[i-days]
}
copy(newCounts[days:], c.counts[:reserveDays-days])
c.counts = newCounts
c.lastUpdateDate = now
}
// startOfDay returns midnight in t's location.
func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
}

View file

@ -1,9 +1,12 @@
package metric
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
)
func TestDateCounter(t *testing.T) {
@ -25,3 +28,107 @@ func TestDateCounter(t *testing.T) {
dcTmp := dc.Snapshot()
require.EqualValues(5, dcTmp.TodayCount())
}
func TestDateCounterRotate(t *testing.T) {
loc := time.FixedZone("test", 8*60*60)
lastUpdateDate := time.Date(2026, time.May, 8, 0, 0, 0, 0, loc)
tests := []struct {
name string
now time.Time
want []int64
wantLastUpdateDate time.Time
}{
{
name: "same day",
now: time.Date(2026, time.May, 8, 12, 30, 0, 0, loc),
want: []int64{10, 7, 3},
wantLastUpdateDate: lastUpdateDate,
},
{
name: "clock skew",
now: time.Date(2026, time.May, 7, 12, 30, 0, 0, loc),
want: []int64{10, 7, 3},
wantLastUpdateDate: lastUpdateDate,
},
{
name: "one day",
now: time.Date(2026, time.May, 9, 12, 30, 0, 0, loc),
want: []int64{0, 10, 7},
wantLastUpdateDate: time.Date(2026, time.May, 9, 0, 0, 0, 0, loc),
},
{
name: "two days",
now: time.Date(2026, time.May, 10, 12, 30, 0, 0, loc),
want: []int64{0, 0, 10},
wantLastUpdateDate: time.Date(2026, time.May, 10, 0, 0, 0, 0, loc),
},
{
name: "all reserved days elapsed",
now: time.Date(2026, time.May, 11, 12, 30, 0, 0, loc),
want: []int64{0, 0, 0},
wantLastUpdateDate: time.Date(2026, time.May, 11, 0, 0, 0, 0, loc),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
dc := newStandardDateCounter(3)
dc.counts = []int64{10, 7, 3}
dc.lastUpdateDate = lastUpdateDate
dc.mu.Lock()
dc.rotate(tt.now)
dc.mu.Unlock()
require.Equal(tt.want, dc.counts)
require.Equal(tt.wantLastUpdateDate, dc.lastUpdateDate)
})
}
}
func TestDateCounterGetLastDaysCountReturnsCopy(t *testing.T) {
require := require.New(t)
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
dc := newStandardDateCounterWithClock(3, clk)
dc.counts = []int64{10, 7, 3}
counts := dc.GetLastDaysCount(2)
require.Equal([]int64{10, 7}, counts)
counts[0] = 100
require.Equal([]int64{10, 7}, dc.GetLastDaysCount(2))
}
func TestDateCounterClear(t *testing.T) {
require := require.New(t)
dc := newStandardDateCounter(3)
dc.counts = []int64{10, 7, 3}
dc.Clear()
require.Equal([]int64{0, 0, 0}, dc.counts)
}
func TestDateCounterConcurrentAccess(t *testing.T) {
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
dc := newStandardDateCounterWithClock(3, clk)
var wg sync.WaitGroup
for range 8 {
wg.Go(func() {
for range 100 {
dc.Inc(1)
dc.Dec(1)
_ = dc.TodayCount()
_ = dc.GetLastDaysCount(3)
_ = dc.Snapshot()
}
})
}
wg.Wait()
}

View file

@ -39,11 +39,14 @@ type HTTPConnectTCPMuxer struct {
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
if err != nil {
return nil, err
}
mux.SetCheckAuthFunc(ret.auth).
SetSuccessHookFunc(ret.sendConnectResponse).
SetFailHookFunc(vhostFailed)
ret.Muxer = mux
return ret, err
return ret, nil
}
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {

View file

@ -24,7 +24,6 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
libio "github.com/fatedier/golib/io"
@ -160,7 +159,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
}
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
vr, ok := rp.vhostRouter.getByRoute(domain, location, routeByHTTPUser)
if ok {
log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
return vr.payload.(*RouteConfig)
@ -171,7 +170,7 @@ func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser str
// CreateConnection create a new connection by route config
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
vr, ok := rp.vhostRouter.getByRoute(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
if ok {
if byEndpoint {
fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn
@ -208,50 +207,6 @@ func checkRouteAuthByRequest(req *http.Request, rc *RouteConfig) bool {
return ok && user == rc.Username && passwd == rc.Password
}
// getVhost tries to get vhost router by route policy.
func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) {
findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) {
vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser)
if ok {
return vr, ok
}
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "")
if ok {
return vr, ok
}
return nil, false
}
// First we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com
vr, ok := findRouter(domain, location, routeByHTTPUser)
if ok {
return vr, ok
}
// e.g. domain = test.example.com, try to match wildcard domains.
// *.example.com
// *.com
domainSplit := strings.Split(domain, ".")
for len(domainSplit) >= 3 {
domainSplit[0] = "*"
domain = strings.Join(domainSplit, ".")
vr, ok = findRouter(domain, location, routeByHTTPUser)
if ok {
return vr, true
}
domainSplit = domainSplit[1:]
}
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
vr, ok = findRouter("*", location, routeByHTTPUser)
if ok {
return vr, true
}
return nil, false
}
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
hj, ok := rw.(http.Hijacker)
if !ok {

View file

@ -29,11 +29,11 @@ type HTTPSMuxer struct {
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
mux.SetFailHookFunc(vhostFailed)
if err != nil {
return nil, err
}
return &HTTPSMuxer{mux}, err
mux.SetFailHookFunc(vhostFailed)
return &HTTPSMuxer{mux}, nil
}
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {

View file

@ -16,23 +16,53 @@ func TestGetHTTPSHostname(t *testing.T) {
require.NoError(err)
defer l.Close()
var conn net.Conn
connCh := make(chan net.Conn, 1)
acceptErrCh := make(chan error, 1)
go func() {
conn, _ = l.Accept()
require.NotNil(conn)
conn, err := l.Accept()
if err != nil {
acceptErrCh <- err
return
}
connCh <- conn
}()
clientErrCh := make(chan error, 1)
go func() {
time.Sleep(100 * time.Millisecond)
tls.Dial("tcp", l.Addr().String(), &tls.Config{
conn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
ServerName: "example.com",
})
if conn != nil {
_ = conn.Close()
}
clientErrCh <- err
}()
time.Sleep(200 * time.Millisecond)
_, infos, err := GetHTTPSHostname(conn)
var conn net.Conn
select {
case conn = <-connCh:
case err := <-acceptErrCh:
require.NoError(err)
case <-time.After(time.Second):
t.Fatal("timed out waiting for accepted connection")
}
require.NotNil(conn)
serverConn, infos, err := GetHTTPSHostname(conn)
if serverConn != nil {
_ = serverConn.Close()
} else {
_ = conn.Close()
}
require.NoError(err)
require.Equal("example.com", infos["Host"])
require.Equal("https", infos["Scheme"])
select {
case <-clientErrCh:
case <-time.After(time.Second):
t.Fatal("timed out waiting for TLS client")
}
}

View file

@ -93,12 +93,19 @@ func (r *Routers) Del(domain, location, httpUser string) {
routersByHTTPUser[httpUser] = newVrs
}
// Get returns the best location match for an exact host and exact HTTP user.
// It does not apply all-users, wildcard-domain, or catch-all-domain fallback.
func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
host = strings.ToLower(host)
r.mutex.RLock()
defer r.mutex.RUnlock()
return r.getLocked(host, path, httpUser)
}
// getLocked performs an exact-host lookup; host must already be lower-cased.
func (r *Routers) getLocked(host, path, httpUser string) (vr *Router, exist bool) {
routersByHTTPUser, found := r.indexByDomain[host]
if !found {
return
@ -117,6 +124,49 @@ func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
return
}
func (r *Routers) getByRoute(host, path, httpUser string) (*Router, bool) {
host = strings.ToLower(host)
r.mutex.RLock()
defer r.mutex.RUnlock()
// First we check the full hostname; if it doesn't exist, then check wildcard domains.
// For example, test.example.com checks *.example.com before falling back to "*".
vr, ok := r.getExactOrAllUsersLocked(host, path, httpUser)
if ok {
return vr, true
}
hostSplit := strings.Split(host, ".")
// Keep two-label hosts out of the wildcard walk, so example.com does not match *.com.
for len(hostSplit) >= 3 {
// Replace the leftmost remaining label with the wildcard marker.
hostSplit[0] = "*"
host = strings.Join(hostSplit, ".")
vr, ok = r.getExactOrAllUsersLocked(host, path, httpUser)
if ok {
return vr, true
}
hostSplit = hostSplit[1:]
}
// Finally, try to check if there is one proxy whose domain is "*", which means match all domains.
return r.getExactOrAllUsersLocked("*", path, httpUser)
}
func (r *Routers) getExactOrAllUsersLocked(host, path, httpUser string) (*Router, bool) {
vr, ok := r.getLocked(host, path, httpUser)
if ok {
return vr, true
}
// Try to check if there is one proxy that doesn't specify routeByHTTPUser, it means match all.
vr, ok = r.getLocked(host, path, "")
if ok {
return vr, true
}
return nil, false
}
func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
routersByHTTPUser, found := r.indexByDomain[host]
if !found {

View file

@ -0,0 +1,257 @@
// Copyright 2017 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package vhost
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestRoutersGet(t *testing.T) {
routers := NewRouters()
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
require.NoError(t, routers.Add("example.com", "/public", "", "exact-all-users"))
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
require.NoError(t, routers.Add("*", "/api", "", "all-domains"))
t.Run("exact host and user match with normalized domain", func(t *testing.T) {
router, ok := routers.Get("EXAMPLE.COM", "/api/users", "alice")
require.True(t, ok)
require.Equal(t, "exact-user", router.payload)
})
t.Run("exact host and empty user match", func(t *testing.T) {
router, ok := routers.Get("EXAMPLE.COM", "/public/docs", "")
require.True(t, ok)
require.Equal(t, "exact-all-users", router.payload)
})
t.Run("does not fall back from named user to empty user", func(t *testing.T) {
// Get intentionally requires an exact HTTP user; route-level fallbacks live in getByRoute.
_, ok := routers.Get("EXAMPLE.COM", "/public/docs", "alice")
require.False(t, ok)
})
t.Run("does not fall back to wildcard domains", func(t *testing.T) {
_, ok := routers.Get("foo.example.com", "/api/users", "")
require.False(t, ok)
})
t.Run("does not fall back to catch-all domain", func(t *testing.T) {
_, ok := routers.Get("missing.test", "/api/users", "")
require.False(t, ok)
})
}
func TestRoutersGetByRoute(t *testing.T) {
routers := NewRouters()
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
require.NoError(t, routers.Add("example.com", "/api", "", "exact-all-users"))
require.NoError(t, routers.Add("exact.example.com", "/api", "", "exact-subdomain"))
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
require.NoError(t, routers.Add("*.foo.example.com", "/api", "", "specific-wildcard"))
require.NoError(t, routers.Add("*.bar.com", "/api", "", "wildcard-parent-domain"))
require.NoError(t, routers.Add("*", "/admin", "root", "all-domains-user"))
require.NoError(t, routers.Add("*", "/", "", "all-domains"))
tests := []struct {
name string
domain string
location string
httpUser string
want string
}{
{
name: "exact domain and http user",
domain: "example.com",
location: "/api/users",
httpUser: "alice",
want: "exact-user",
},
{
name: "exact domain falls back to all users",
domain: "example.com",
location: "/api/users",
httpUser: "bob",
want: "exact-all-users",
},
{
name: "wildcard domain uses all users fallback",
domain: "foo.example.com",
location: "/api/users",
httpUser: "bob",
want: "wildcard-all-users",
},
{
name: "mixed-case domain is normalized",
domain: "Foo.Example.Com",
location: "/api/users",
httpUser: "bob",
want: "wildcard-all-users",
},
{
name: "exact domain wins over wildcard domain",
domain: "exact.example.com",
location: "/api/users",
httpUser: "bob",
want: "exact-subdomain",
},
{
name: "more specific wildcard wins over broader wildcard",
domain: "bar.foo.example.com",
location: "/api/users",
httpUser: "bob",
want: "specific-wildcard",
},
{
name: "wildcard walk checks parent domains",
domain: "a.b.bar.com",
location: "/api/users",
httpUser: "bob",
want: "wildcard-parent-domain",
},
{
name: "catch-all domain fallback",
domain: "foo.test.com",
location: "/other",
httpUser: "bob",
want: "all-domains",
},
{
name: "catch-all domain honors http user",
domain: "foo.test.com",
location: "/admin/panel",
httpUser: "root",
want: "all-domains-user",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router, ok := routers.getByRoute(tt.domain, tt.location, tt.httpUser)
require.True(t, ok)
require.Equal(t, tt.want, router.payload)
})
}
}
func TestRoutersGetByRouteNoMatch(t *testing.T) {
routers := NewRouters()
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
require.NoError(t, routers.Add("*.com", "/api", "", "top-level-wildcard"))
tests := []struct {
name string
domain string
location string
}{
{
name: "two-label domain does not enter wildcard walk",
domain: "example.com",
location: "/api/users",
},
{
name: "missing catch-all remains no match",
domain: "foo.test.com",
location: "/api/users",
},
{
name: "wrong path remains no match",
domain: "foo.example.com",
location: "/other",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router, ok := routers.getByRoute(tt.domain, tt.location, "bob")
require.False(t, ok)
require.Nil(t, router)
})
}
}
func TestRoutersConcurrentGetByRouteAndAdd(t *testing.T) {
routers := NewRouters()
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard"))
const readers = 8
const iterations = 200
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
start := make(chan struct{})
errCh := make(chan error, readers+1)
var wg sync.WaitGroup
for id := range readers {
wg.Go(func() {
<-start
for j := range iterations {
select {
case <-ctx.Done():
return
default:
}
router, ok := routers.getByRoute("foo.example.com", "/api/users", "")
if !ok || router == nil || router.payload != "wildcard" {
errCh <- fmt.Errorf("reader %d iteration %d got router=%v ok=%v", id, j, router, ok)
return
}
}
})
}
wg.Go(func() {
<-start
for i := range iterations {
select {
case <-ctx.Done():
return
default:
}
err := routers.Add(fmt.Sprintf("host-%d.example.com", i), "/api", "", i)
if err != nil {
errCh <- err
return
}
}
})
close(start)
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
cancel()
t.Fatal("concurrent route lookup and add timed out")
}
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
}

View file

@ -148,42 +148,10 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
}
func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) {
vr, ok := v.registryRouter.Get(inName, inPath, inHTTPUser)
vr, ok := v.registryRouter.getByRoute(name, path, httpUser)
if ok {
return vr.payload.(*Listener), true
}
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
vr, ok = v.registryRouter.Get(inName, inPath, "")
if ok {
return vr.payload.(*Listener), true
}
return nil, false
}
// first we check the full hostname
// if not exist, then check the wildcard_domain such as *.example.com
l, ok := findRouter(name, path, httpUser)
if ok {
return l, true
}
domainSplit := strings.Split(name, ".")
for len(domainSplit) >= 3 {
domainSplit[0] = "*"
name = strings.Join(domainSplit, ".")
l, ok = findRouter(name, path, httpUser)
if ok {
return l, true
}
domainSplit = domainSplit[1:]
}
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
l, ok = findRouter("*", path, httpUser)
if ok {
return l, true
}
return nil, false
}

View file

@ -18,6 +18,8 @@ import (
"math/rand/v2"
"time"
"k8s.io/utils/clock"
"github.com/fatedier/frp/pkg/util/util"
)
@ -48,6 +50,7 @@ type FastBackoffOptions struct {
type fastBackoffImpl struct {
options FastBackoffOptions
clock clock.PassiveClock
lastCalledTime time.Time
consecutiveErrCount int
@ -57,18 +60,26 @@ type fastBackoffImpl struct {
}
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
return newFastBackoffManagerWithClock(options, clock.RealClock{})
}
func newFastBackoffManagerWithClock(options FastBackoffOptions, clk clock.PassiveClock) BackoffManager {
if clk == nil {
clk = clock.RealClock{}
}
return &fastBackoffImpl{
options: options,
clock: clk,
countsInFastRetryWindow: 1,
}
}
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
if f.lastCalledTime.IsZero() {
f.lastCalledTime = time.Now()
f.lastCalledTime = f.clock.Now()
return f.options.Duration
}
now := time.Now()
now := f.clock.Now()
f.lastCalledTime = now
if previousConditionError {

View file

@ -0,0 +1,27 @@
package wait
import (
"testing"
"time"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
)
func TestFastBackoffManagerUsesClock(t *testing.T) {
require := require.New(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
backoff := newFastBackoffManagerWithClock(FastBackoffOptions{
Duration: time.Second,
}, clk).(*fastBackoffImpl)
require.Equal(time.Second, backoff.Backoff(0, false))
require.Equal(start, backoff.lastCalledTime)
next := start.Add(time.Minute)
clk.SetTime(next)
require.Equal(time.Second, backoff.Backoff(time.Second, false))
require.Equal(next, backoff.lastCalledTime)
}

View file

@ -19,7 +19,6 @@ import "strings"
// LogWriter forwards writes to frp's logger at configurable level.
// It is safe for concurrent use as long as the underlying Logger is thread-safe.
type LogWriter struct {
xl *Logger
logFunc func(string)
}
@ -31,35 +30,30 @@ func (w LogWriter) Write(p []byte) (n int, err error) {
func NewTraceWriter(xl *Logger) LogWriter {
return LogWriter{
xl: xl,
logFunc: func(msg string) { xl.Tracef("%s", msg) },
}
}
func NewDebugWriter(xl *Logger) LogWriter {
return LogWriter{
xl: xl,
logFunc: func(msg string) { xl.Debugf("%s", msg) },
}
}
func NewInfoWriter(xl *Logger) LogWriter {
return LogWriter{
xl: xl,
logFunc: func(msg string) { xl.Infof("%s", msg) },
}
}
func NewWarnWriter(xl *Logger) LogWriter {
return LogWriter{
xl: xl,
logFunc: func(msg string) { xl.Warnf("%s", msg) },
}
}
func NewErrorWriter(xl *Logger) LogWriter {
return LogWriter{
xl: xl,
logFunc: func(msg string) { xl.Errorf("%s", msg) },
}
}

View file

@ -286,7 +286,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri
r.mu.Lock()
defer r.mu.Unlock()
r.routes[name] = &routeElement{
name: name,
routes: routes,
conn: conn,
}
@ -383,7 +382,6 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
}
type routeElement struct {
name string
routes []net.IPNet
conn io.ReadWriteCloser
}

View file

@ -7,6 +7,8 @@ import (
"sync"
"time"
"k8s.io/utils/clock"
"github.com/fatedier/frp/pkg/config/types"
)
@ -38,16 +40,25 @@ type Manager struct {
bindAddr string
netType string
clock clock.WithTicker
mu sync.Mutex
}
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{})
}
func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager {
if clk == nil {
clk = clock.RealClock{}
}
pm := &Manager{
reservedPorts: make(map[string]*PortCtx),
usedPorts: make(map[int]*PortCtx),
freePorts: make(map[int]struct{}),
bindAddr: bindAddr,
netType: netType,
clock: clk,
}
if len(allowPorts) > 0 {
for _, pair := range allowPorts {
@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
portCtx := &PortCtx{
ProxyName: name,
Closed: false,
UpdateTime: time.Now(),
UpdateTime: pm.clock.Now(),
}
var ok bool
@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
if ctx, ok := pm.reservedPorts[name]; ok {
if pm.isPortAvailable(ctx.Port) {
realPort = ctx.Port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
pm.markPortAcquiredLocked(name, realPort, portCtx)
return
}
}
@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
}
if pm.isPortAvailable(k) {
realPort = k
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
pm.markPortAcquiredLocked(name, realPort, portCtx)
break
}
}
@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
if _, ok = pm.freePorts[port]; ok {
if pm.isPortAvailable(port) {
realPort = port
pm.usedPorts[realPort] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, realPort)
pm.markPortAcquiredLocked(name, realPort, portCtx)
} else {
err = ErrPortUnAvailable
}
@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
return
}
// markPortAcquiredLocked records a successful acquisition. pm.mu must be held.
func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) {
pm.usedPorts[port] = portCtx
pm.reservedPorts[name] = portCtx
delete(pm.freePorts, port)
}
func (pm *Manager) isPortAvailable(port int) bool {
if pm.netType == "udp" {
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) {
pm.freePorts[port] = struct{}{}
delete(pm.usedPorts, port)
ctx.Closed = true
ctx.UpdateTime = time.Now()
ctx.UpdateTime = pm.clock.Now()
}
}
// Release reserved port if it isn't used in last 24 hours.
func (pm *Manager) cleanReservedPortsWorker() {
pm.cleanReservedPortsWorkerUntil(nil)
}
func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) {
ticker := pm.clock.NewTicker(CleanReservedPortsInterval)
defer ticker.Stop()
for {
time.Sleep(CleanReservedPortsInterval)
select {
case <-ticker.C():
pm.cleanReservedPortsOnce()
case <-stopCh:
return
}
}
}
func (pm *Manager) cleanReservedPortsOnce() {
pm.mu.Lock()
defer pm.mu.Unlock()
for name, ctx := range pm.reservedPorts {
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration {
delete(pm.reservedPorts, name)
}
}
pm.mu.Unlock()
}
}

View file

@ -0,0 +1,72 @@
package ports
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
clocktesting "k8s.io/utils/clock/testing"
"github.com/fatedier/frp/pkg/config/types"
)
func TestManagerUsesClockForPortTimestamps(t *testing.T) {
require := require.New(t)
port := freeTCPPort(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
realPort, err := pm.Acquire("proxy", port)
require.NoError(err)
require.Equal(port, realPort)
require.Equal(start, pm.usedPorts[port].UpdateTime)
releasedAt := start.Add(time.Minute)
clk.SetTime(releasedAt)
pm.Release(port)
require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime)
}
func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) {
require := require.New(t)
port := freeTCPPort(t)
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
realPort, err := pm.Acquire("proxy", port)
require.NoError(err)
require.Equal(port, realPort)
pm.Release(port)
require.True(pm.hasReservedPort("proxy"))
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute)
require.Eventually(func() bool {
return !pm.hasReservedPort("proxy")
}, time.Second, time.Millisecond)
}
func (pm *Manager) hasReservedPort(name string) bool {
pm.mu.Lock()
defer pm.mu.Unlock()
_, ok := pm.reservedPorts[name]
return ok
}
func freeTCPPort(t *testing.T) int {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port
}

View file

@ -18,6 +18,8 @@ import (
"fmt"
"sync"
"time"
"k8s.io/utils/clock"
)
// ClientInfo captures metadata about a connected frpc instance.
@ -42,12 +44,21 @@ type ClientRegistry struct {
mu sync.RWMutex
clients map[string]*ClientInfo
runIndex map[string]string
clock clock.PassiveClock
}
func NewClientRegistry() *ClientRegistry {
return newClientRegistryWithClock(clock.RealClock{})
}
func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry {
if clk == nil {
clk = clock.RealClock{}
}
return &ClientRegistry{
clients: make(map[string]*ClientInfo),
runIndex: make(map[string]string),
clock: clk,
}
}
@ -64,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
key = cr.composeClientKey(user, effectiveID)
enforceUnique := rawClientID != ""
now := time.Now()
now := cr.clock.Now()
cr.mu.Lock()
defer cr.mu.Unlock()
@ -116,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) {
} else {
info.RunID = ""
info.Online = false
now := time.Now()
now := cr.clock.Now()
info.DisconnectedAt = now
}
}

View file

@ -16,6 +16,9 @@ package registry
import (
"testing"
"time"
clocktesting "k8s.io/utils/clock/testing"
"github.com/fatedier/frp/pkg/proto/wire"
)
@ -35,3 +38,37 @@ func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
}
}
func TestClientRegistryUsesClockForTimestamps(t *testing.T) {
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
clk := clocktesting.NewFakeClock(start)
registry := newClientRegistryWithClock(clk)
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
if conflict {
t.Fatal("unexpected client conflict")
}
info, ok := registry.GetByKey(key)
if !ok {
t.Fatalf("client %q not found", key)
}
if !info.FirstConnectedAt.Equal(start) {
t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt)
}
if !info.LastConnectedAt.Equal(start) {
t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt)
}
disconnectedAt := start.Add(time.Minute)
clk.SetTime(disconnectedAt)
registry.MarkOfflineByRunID("run-id")
info, ok = registry.GetByKey(key)
if !ok {
t.Fatalf("client %q not found after disconnect", key)
}
if !info.DisconnectedAt.Equal(disconnectedAt) {
t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt)
}
}

View file

@ -1,62 +0,0 @@
package framework
import (
"sync"
)
// CleanupActionHandle is an integer pointer type for handling cleanup action
type CleanupActionHandle *int
type cleanupFuncHandle struct {
actionHandle CleanupActionHandle
actionHook func()
}
var (
cleanupActionsLock sync.Mutex
cleanupHookList = []cleanupFuncHandle{}
)
// AddCleanupAction installs a function that will be called in the event of the
// whole test being terminated. This allows arbitrary pieces of the overall
// test to hook into SynchronizedAfterSuite().
// The hooks are called in last-in-first-out order.
func AddCleanupAction(fn func()) CleanupActionHandle {
p := CleanupActionHandle(new(int))
cleanupActionsLock.Lock()
defer cleanupActionsLock.Unlock()
c := cleanupFuncHandle{actionHandle: p, actionHook: fn}
cleanupHookList = append([]cleanupFuncHandle{c}, cleanupHookList...)
return p
}
// RemoveCleanupAction removes a function that was installed by
// AddCleanupAction.
func RemoveCleanupAction(p CleanupActionHandle) {
cleanupActionsLock.Lock()
defer cleanupActionsLock.Unlock()
for i, item := range cleanupHookList {
if item.actionHandle == p {
cleanupHookList = append(cleanupHookList[:i], cleanupHookList[i+1:]...)
break
}
}
}
// RunCleanupActions runs all functions installed by AddCleanupAction. It does
// not remove them (see RemoveCleanupAction) but it does run unlocked, so they
// may remove themselves.
func RunCleanupActions() {
list := []func(){}
func() {
cleanupActionsLock.Lock()
defer cleanupActionsLock.Unlock()
for _, p := range cleanupHookList {
list = append(list, p.actionHook)
}
}()
// Run unlocked.
for _, fn := range list {
fn()
}
}

View file

@ -23,11 +23,6 @@ func ExpectNotEqual(actual any, extra any, explain ...any) {
gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...)
}
// ExpectError expects an error happens, otherwise an exception raises
func ExpectError(err error, explain ...any) {
gomega.ExpectWithOffset(1, err).To(gomega.HaveOccurred(), explain...)
}
func ExpectErrorWithOffset(offset int, err error, explain ...any) {
gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...)
}
@ -47,11 +42,6 @@ func ExpectContainSubstring(actual, substr string, explain ...any) {
gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...)
}
// ExpectConsistOf expects actual contains precisely the extra elements. The ordering of the elements does not matter.
func ExpectConsistOf(actual any, extra any, explain ...any) {
gomega.ExpectWithOffset(1, actual).To(gomega.ConsistOf(extra), explain...)
}
func ExpectContainElements(actual any, extra any, explain ...any) {
gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...)
}
@ -60,16 +50,6 @@ func ExpectNotContainElements(actual any, extra any, explain ...any) {
gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...)
}
// ExpectHaveKey expects the actual map has the key in the keyset
func ExpectHaveKey(actual any, key any, explain ...any) {
gomega.ExpectWithOffset(1, actual).To(gomega.HaveKey(key), explain...)
}
// ExpectEmpty expects actual is empty
func ExpectEmpty(actual any, explain ...any) {
gomega.ExpectWithOffset(1, actual).To(gomega.BeEmpty(), explain...)
}
func ExpectTrue(actual any, explain ...any) {
gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...)
}

View file

@ -38,11 +38,6 @@ type Framework struct {
// Multiple default mock servers used for e2e testing.
mockServers *MockServers
// To make sure that this framework cleans up after itself, no matter what,
// we install a Cleanup action before each test and clear it after. If we
// should abort, the AfterSuite hook should run all Cleanup actions.
cleanupHandle CleanupActionHandle
// beforeEachStarted indicates that BeforeEach has started
beforeEachStarted bool
@ -87,8 +82,6 @@ func NewFramework(opt Options) *Framework {
func (f *Framework) BeforeEach() {
f.beforeEachStarted = true
f.cleanupHandle = AddCleanupAction(f.AfterEach)
dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*")
ExpectNoError(err)
f.TempDirectory = dir
@ -113,8 +106,6 @@ func (f *Framework) AfterEach() {
return
}
RemoveCleanupAction(f.cleanupHandle)
// stop processor
for _, p := range f.serverProcesses {
_ = p.Stop()
@ -266,10 +257,6 @@ func (f *Framework) AllocPortExcludingRanges(ranges ...[2]int) int {
return 0
}
func (f *Framework) ReleasePort(port int) {
f.portAllocator.Release(port)
}
func (f *Framework) RunServer(portName string, s server.Server) {
f.servers = append(f.servers, s)
if s.BindPort() > 0 && portName != "" {

View file

@ -75,11 +75,3 @@ func (m *MockServers) GetTemplateParams() map[string]any {
ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort()
return ret
}
func (m *MockServers) GetParam(key string) any {
params := m.GetTemplateParams()
if v, ok := params[key]; ok {
return v
}
return nil
}

View file

@ -124,7 +124,3 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) {
}
}
}
func (e *RequestExpect) Do() (*request.Response, error) {
return e.req.Do()
}

View file

@ -31,13 +31,6 @@ func New(options ...Option) *Server {
return s
}
func WithBindAddr(addr string) Option {
return func(s *Server) *Server {
s.bindAddr = addr
return s
}
}
func WithBindPort(port int) Option {
return func(s *Server) *Server {
s.bindPort = port

View file

@ -61,21 +61,6 @@ func WithBindPort(port int) Option {
return func(s *Server) { s.bindPort = port }
}
func WithClientCredentials(id, secret string) Option {
return func(s *Server) {
s.clientID = id
s.clientSecret = secret
}
}
func WithAudience(aud string) Option {
return func(s *Server) { s.audience = aud }
}
func WithSubject(sub string) Option {
return func(s *Server) { s.subject = sub }
}
func WithExpiresIn(seconds int) Option {
return func(s *Server) { s.expiresIn = seconds }
}

View file

@ -41,14 +41,9 @@ type Process struct {
waitErr error
started bool
beforeStopHandler func()
stopped bool
}
func New(path string, params []string) *Process {
return NewWithEnvs(path, params, nil)
}
func NewWithEnvs(path string, params []string, envs []string) *Process {
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(ctx, path, params...)
@ -100,9 +95,6 @@ func (p *Process) Stop() error {
defer func() {
p.stopped = true
}()
if p.beforeStopHandler != nil {
p.beforeStopHandler()
}
p.cancel()
<-p.done
return p.waitErr
@ -125,10 +117,6 @@ func (p *Process) CountOutput(pattern string) int {
return strings.Count(p.Output(), pattern)
}
func (p *Process) SetBeforeStopHandler(fn func()) {
p.beforeStopHandler = fn
}
// WaitForOutput polls the combined process output until the pattern is found
// count time(s) or the timeout is reached. It also returns early if the process exits.
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {

View file

@ -26,7 +26,6 @@ type Request struct {
port int
body []byte
timeout time.Duration
resolver *net.Resolver
// for http or https
method string
@ -134,11 +133,6 @@ func (r *Request) Body(content []byte) *Request {
return r
}
func (r *Request) Resolver(resolver *net.Resolver) *Request {
r.resolver = resolver
return r
}
func (r *Request) Do() (*Response, error) {
var (
conn net.Conn
@ -169,7 +163,7 @@ func (r *Request) Do() (*Response, error) {
return nil, err
}
} else {
dialer := &net.Dialer{Resolver: r.resolver}
dialer := &net.Dialer{}
switch r.protocol {
case "tcp":
conn, err = dialer.Dial("tcp", addr)
@ -225,7 +219,6 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma
Timeout: time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
Resolver: r.resolver,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,