diff --git a/go.mod b/go.mod index 46bbc2d4..d9d53c84 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( gopkg.in/ini.v1 v1.67.0 k8s.io/apimachinery v0.28.8 k8s.io/client-go v0.28.8 + k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 ) require ( @@ -75,7 +76,6 @@ require ( google.golang.org/protobuf v1.36.5 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/yaml v1.3.0 // indirect ) diff --git a/pkg/config/source/config_source.go b/pkg/config/source/config_source.go index ea8a2af6..95c8cfa1 100644 --- a/pkg/config/source/config_source.go +++ b/pkg/config/source/config_source.go @@ -14,11 +14,7 @@ package source -import ( - "fmt" - - v1 "github.com/fatedier/frp/pkg/config/v1" -) +import v1 "github.com/fatedier/frp/pkg/config/v1" // ConfigSource implements Source for in-memory configuration. // All operations are thread-safe. @@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies)) for _, p := range proxies { - if p == nil { - return fmt.Errorf("proxy cannot be nil") - } - name := p.GetBaseConfig().Name - if name == "" { - return fmt.Errorf("proxy name cannot be empty") + name, err := validateProxyName(p) + if err != nil { + return err } nextProxies[name] = p } nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors)) for _, v := range visitors { - if v == nil { - return fmt.Errorf("visitor cannot be nil") - } - name := v.GetBaseConfig().Name - if name == "" { - return fmt.Errorf("visitor name cannot be empty") + name, err := validateVisitorName(v) + if err != nil { + return err } nextVisitors[name] = v } diff --git a/pkg/config/source/store.go b/pkg/config/source/store.go index d1bf4fb5..f568f885 100644 --- a/pkg/config/source/store.go +++ b/pkg/config/source/store.go @@ -43,6 +43,11 @@ var ( ErrNotFound = errors.New("not found") ) +const ( + storeKindProxy = "proxy" + storeKindVisitor = "visitor" +) + func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) { if cfg.Path == "" { return nil, fmt.Errorf("path is required") @@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error { return nil } -func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error { - if proxy == nil { - return fmt.Errorf("proxy cannot be nil") +func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error { + if err := s.saveToFileUnlocked(); err != nil { + rollback() + return fmt.Errorf("failed to persist: %w", err) + } + return nil +} + +// Store map selectors return the target map for generic helpers. +func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer { + return s.proxies +} + +func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer { + return s.visitors +} + +// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps. +// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer. +func addStoreEntry[T any]( + s *StoreSource, + entriesFn func(*StoreSource) map[string]T, + kind string, + name string, + value T, +) error { + s.mu.Lock() + defer s.mu.Unlock() + + entries := entriesFn(s) + if _, exists := entries[name]; exists { + return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name) } - name := proxy.GetBaseConfig().Name + entries[name] = value + return s.persistOrRollbackUnlocked(func() { + delete(entries, name) + }) +} + +func updateStoreEntry[T any]( + s *StoreSource, + entriesFn func(*StoreSource) map[string]T, + kind string, + name string, + value T, +) error { + s.mu.Lock() + defer s.mu.Unlock() + + entries := entriesFn(s) + old, exists := entries[name] + if !exists { + return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name) + } + + entries[name] = value + return s.persistOrRollbackUnlocked(func() { + entries[name] = old + }) +} + +func removeStoreEntry[T any]( + s *StoreSource, + entriesFn func(*StoreSource) map[string]T, + kind string, + name string, +) error { if name == "" { - return fmt.Errorf("proxy name cannot be empty") + return fmt.Errorf("%s name cannot be empty", kind) } s.mu.Lock() defer s.mu.Unlock() - if _, exists := s.proxies[name]; exists { - return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name) + entries := entriesFn(s) + old, exists := entries[name] + if !exists { + return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name) } - s.proxies[name] = proxy + delete(entries, name) + return s.persistOrRollbackUnlocked(func() { + entries[name] = old + }) +} - if err := s.saveToFileUnlocked(); err != nil { - delete(s.proxies, name) - return fmt.Errorf("failed to persist: %w", err) +func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error { + name, err := validateProxyName(proxy) + if err != nil { + return err } - return nil + return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy) } func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error { - if proxy == nil { - return fmt.Errorf("proxy cannot be nil") + name, err := validateProxyName(proxy) + if err != nil { + return err } - - name := proxy.GetBaseConfig().Name - if name == "" { - return fmt.Errorf("proxy name cannot be empty") - } - - s.mu.Lock() - defer s.mu.Unlock() - - oldProxy, exists := s.proxies[name] - if !exists { - return fmt.Errorf("%w: proxy %q", ErrNotFound, name) - } - - s.proxies[name] = proxy - - if err := s.saveToFileUnlocked(); err != nil { - s.proxies[name] = oldProxy - return fmt.Errorf("failed to persist: %w", err) - } - return nil + return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy) } func (s *StoreSource) RemoveProxy(name string) error { - if name == "" { - return fmt.Errorf("proxy name cannot be empty") - } - - s.mu.Lock() - defer s.mu.Unlock() - - oldProxy, exists := s.proxies[name] - if !exists { - return fmt.Errorf("%w: proxy %q", ErrNotFound, name) - } - - delete(s.proxies, name) - - if err := s.saveToFileUnlocked(); err != nil { - s.proxies[name] = oldProxy - return fmt.Errorf("failed to persist: %w", err) - } - return nil + return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name) } func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer { @@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer { } func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error { - if visitor == nil { - return fmt.Errorf("visitor cannot be nil") + name, err := validateVisitorName(visitor) + if err != nil { + return err } - - name := visitor.GetBaseConfig().Name - if name == "" { - return fmt.Errorf("visitor name cannot be empty") - } - - s.mu.Lock() - defer s.mu.Unlock() - - if _, exists := s.visitors[name]; exists { - return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name) - } - - s.visitors[name] = visitor - - if err := s.saveToFileUnlocked(); err != nil { - delete(s.visitors, name) - return fmt.Errorf("failed to persist: %w", err) - } - return nil + return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor) } func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error { - if visitor == nil { - return fmt.Errorf("visitor cannot be nil") + name, err := validateVisitorName(visitor) + if err != nil { + return err } - - name := visitor.GetBaseConfig().Name - if name == "" { - return fmt.Errorf("visitor name cannot be empty") - } - - s.mu.Lock() - defer s.mu.Unlock() - - oldVisitor, exists := s.visitors[name] - if !exists { - return fmt.Errorf("%w: visitor %q", ErrNotFound, name) - } - - s.visitors[name] = visitor - - if err := s.saveToFileUnlocked(); err != nil { - s.visitors[name] = oldVisitor - return fmt.Errorf("failed to persist: %w", err) - } - return nil + return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor) } func (s *StoreSource) RemoveVisitor(name string) error { - if name == "" { - return fmt.Errorf("visitor name cannot be empty") - } - - s.mu.Lock() - defer s.mu.Unlock() - - oldVisitor, exists := s.visitors[name] - if !exists { - return fmt.Errorf("%w: visitor %q", ErrNotFound, name) - } - - delete(s.visitors, name) - - if err := s.saveToFileUnlocked(); err != nil { - s.visitors[name] = oldVisitor - return fmt.Errorf("failed to persist: %w", err) - } - return nil + return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name) } func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer { diff --git a/pkg/config/source/store_test.go b/pkg/config/source/store_test.go index bb5382b0..8bb107c5 100644 --- a/pkg/config/source/store_test.go +++ b/pkg/config/source/store_test.go @@ -17,6 +17,7 @@ package source import ( "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/require" @@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol) } +func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) { + require := require.New(t) + + storeSource := newTestStoreSource(t) + + proxyCfg := mockProxy("proxy1") + visitorCfg := mockVisitor("visitor1") + + require.NoError(storeSource.AddProxy(proxyCfg)) + require.NoError(storeSource.AddVisitor(visitorCfg)) + require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists) + require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists) + require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty") + require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty") + + updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig) + updatedProxy.RemotePort = 19090 + require.NoError(storeSource.UpdateProxy(updatedProxy)) + require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort) + + updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig) + updatedVisitor.ServerName = "updated-server" + require.NoError(storeSource.UpdateVisitor(updatedVisitor)) + require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName) + + require.NoError(storeSource.RemoveProxy("proxy1")) + require.Nil(storeSource.GetProxy("proxy1")) + require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound) + + require.NoError(storeSource.RemoveVisitor("visitor1")) + require.Nil(storeSource.GetVisitor("visitor1")) + require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound) + + require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound) + require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound) +} + +func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("chmod does not make directories unwritable on Windows") + } + if os.Getuid() == 0 { + t.Skip("chmod does not block writes for uid 0") + } + + require := require.New(t) + + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + storeSource, err := NewStoreSource(StoreSourceConfig{Path: path}) + require.NoError(err) + + proxyCfg := mockProxy("proxy1") + visitorCfg := mockVisitor("visitor1") + originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort + originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName + require.NoError(storeSource.AddProxy(proxyCfg)) + require.NoError(storeSource.AddVisitor(visitorCfg)) + + require.NoError(os.Chmod(dir, 0o500)) + t.Cleanup(func() { + _ = os.Chmod(dir, 0o700) + }) + + requirePersistError := func(err error) { + t.Helper() + require.Error(err) + require.ErrorContains(err, "failed to persist") + require.NotErrorIs(err, ErrAlreadyExists) + require.NotErrorIs(err, ErrNotFound) + } + + requirePersistError(storeSource.AddProxy(mockProxy("proxy2"))) + require.Nil(storeSource.GetProxy("proxy2")) + + updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig) + updatedProxy.RemotePort = 19090 + requirePersistError(storeSource.UpdateProxy(updatedProxy)) + require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort) + + requirePersistError(storeSource.RemoveProxy("proxy1")) + require.NotNil(storeSource.GetProxy("proxy1")) + + requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2"))) + require.Nil(storeSource.GetVisitor("visitor2")) + + updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig) + updatedVisitor.ServerName = "updated-server" + requirePersistError(storeSource.UpdateVisitor(updatedVisitor)) + require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName) + + requirePersistError(storeSource.RemoveVisitor("visitor1")) + require.NotNil(storeSource.GetVisitor("visitor1")) +} + func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) { require := require.New(t) diff --git a/pkg/config/source/validation.go b/pkg/config/source/validation.go new file mode 100644 index 00000000..55bc4220 --- /dev/null +++ b/pkg/config/source/validation.go @@ -0,0 +1,43 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package source + +import ( + "fmt" + + v1 "github.com/fatedier/frp/pkg/config/v1" +) + +func validateProxyName(proxy v1.ProxyConfigurer) (string, error) { + if proxy == nil { + return "", fmt.Errorf("proxy cannot be nil") + } + name := proxy.GetBaseConfig().Name + if name == "" { + return "", fmt.Errorf("proxy name cannot be empty") + } + return name, nil +} + +func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) { + if visitor == nil { + return "", fmt.Errorf("visitor cannot be nil") + } + name := visitor.GetBaseConfig().Name + if name == "" { + return "", fmt.Errorf("visitor name cannot be empty") + } + return name, nil +} diff --git a/pkg/config/v1/validation/auth.go b/pkg/config/v1/validation/auth.go new file mode 100644 index 00000000..c70235dc --- /dev/null +++ b/pkg/config/v1/validation/auth.go @@ -0,0 +1,43 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "fmt" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/policy/security" +) + +func (v *ConfigValidator) validateAuthTokenSource(token string, tokenSource *v1.ValueSource) error { + var errs error + // Preserve the previous client/server validation order for joined errors. + if token != "" && tokenSource != nil { + errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource")) + } + if tokenSource == nil { + return errs + } + + if tokenSource.Type == "exec" { + if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { + errs = AppendError(errs, err) + } + } + if err := tokenSource.Validate(); err != nil { + errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err)) + } + return errs +} diff --git a/pkg/config/v1/validation/auth_test.go b/pkg/config/v1/validation/auth_test.go new file mode 100644 index 00000000..4e90b44c --- /dev/null +++ b/pkg/config/v1/validation/auth_test.go @@ -0,0 +1,228 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validation + +import ( + "testing" + + "github.com/stretchr/testify/require" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/policy/security" +) + +const ( + tokenSourceConflictErr = "cannot specify both auth.token and auth.tokenSource" + tokenSourceExecErr = "unsafe feature \"TokenSourceExec\" is not enabled. To enable it, ensure it is allowed in the configuration or command line flags" + invalidFileSourceErr = "invalid auth.tokenSource: file configuration is required when type is 'file'" + unsupportedSourceErr = "invalid auth.tokenSource: unsupported value source type: env (only 'file' and 'exec' are supported)" +) + +func TestValidateAuthTokenSource(t *testing.T) { + for _, tc := range authTokenSourceTestCases() { + t.Run(tc.name, func(t *testing.T) { + validator := newAuthTokenSourceValidator(tc.unsafeAllowed) + err := validator.validateAuthTokenSource(tc.token, tc.tokenSource()) + requireValidationErrors(t, err, tc.wantErrs) + }) + } +} + +func TestValidateClientAuthTokenSource(t *testing.T) { + for _, tc := range authTokenSourceTestCases() { + t.Run(tc.name, func(t *testing.T) { + auth := v1.AuthClientConfig{ + Method: v1.AuthMethodToken, + Token: tc.token, + TokenSource: tc.tokenSource(), + } + validator := newAuthTokenSourceValidator(tc.unsafeAllowed) + _, err := validator.ValidateClientCommonConfig(validClientConfigWithAuth(auth)) + requireValidationErrors(t, err, tc.wantErrs) + }) + } +} + +func TestValidateServerAuthTokenSource(t *testing.T) { + for _, tc := range authTokenSourceTestCases() { + t.Run(tc.name, func(t *testing.T) { + auth := v1.AuthServerConfig{ + Method: v1.AuthMethodToken, + Token: tc.token, + TokenSource: tc.tokenSource(), + } + validator := newAuthTokenSourceValidator(tc.unsafeAllowed) + _, err := validator.ValidateServerConfig(validServerConfigWithAuth(auth)) + requireValidationErrors(t, err, tc.wantErrs) + }) + } +} + +type authTokenSourceTestCase struct { + name string + token string + tokenSource func() *v1.ValueSource + unsafeAllowed bool + wantErrs []string +} + +func authTokenSourceTestCases() []authTokenSourceTestCase { + return []authTokenSourceTestCase{ + { + name: "empty token config", + tokenSource: nilTokenSource, + }, + { + name: "valid file tokenSource", + tokenSource: validFileTokenSource, + }, + { + name: "literal token without tokenSource", + token: "token", + tokenSource: nilTokenSource, + }, + { + name: "literal token conflicts with file tokenSource", + token: "token", + tokenSource: validFileTokenSource, + wantErrs: []string{tokenSourceConflictErr}, + }, + { + name: "exec tokenSource requires unsafe feature", + tokenSource: validExecTokenSource, + wantErrs: []string{tokenSourceExecErr}, + }, + { + name: "exec tokenSource with unsafe feature allowed", + tokenSource: validExecTokenSource, + unsafeAllowed: true, + }, + { + name: "literal token conflicts with exec tokenSource and unsafe feature disabled", + token: "token", + tokenSource: validExecTokenSource, + wantErrs: []string{ + tokenSourceConflictErr, + tokenSourceExecErr, + }, + }, + { + name: "literal token conflicts with exec tokenSource and unsafe feature allowed", + token: "token", + tokenSource: validExecTokenSource, + unsafeAllowed: true, + wantErrs: []string{tokenSourceConflictErr}, + }, + { + name: "invalid file tokenSource is wrapped", + tokenSource: invalidFileTokenSource, + wantErrs: []string{invalidFileSourceErr}, + }, + { + name: "unsupported tokenSource type is wrapped", + tokenSource: unsupportedTokenSource, + wantErrs: []string{unsupportedSourceErr}, + }, + } +} + +func newAuthTokenSourceValidator(unsafeAllowed bool) *ConfigValidator { + if !unsafeAllowed { + return NewConfigValidator(nil) + } + return NewConfigValidator(security.NewUnsafeFeatures([]string{security.TokenSourceExec})) +} + +func requireValidationErrors(t *testing.T, err error, wantErrs []string) { + t.Helper() + if len(wantErrs) == 0 { + require.NoError(t, err) + return + } + require.Error(t, err) + // Client/server validators may wrap joined errors in another join layer; compare leaf errors. + gotErrs := unwrapValidationErrors(err) + require.Len(t, gotErrs, len(wantErrs)) + for i, wantErr := range wantErrs { + require.EqualError(t, gotErrs[i], wantErr) + } +} + +func unwrapValidationErrors(err error) []error { + type joinedError interface { + Unwrap() []error + } + joined, ok := err.(joinedError) + if !ok { + return []error{err} + } + + var errs []error + for _, err := range joined.Unwrap() { + errs = append(errs, unwrapValidationErrors(err)...) + } + return errs +} + +// nilTokenSource keeps the shared table shape uniform for cases without a tokenSource. +func nilTokenSource() *v1.ValueSource { + return nil +} + +func validFileTokenSource() *v1.ValueSource { + return &v1.ValueSource{ + Type: "file", + File: &v1.FileSource{Path: "token.txt"}, + } +} + +func validExecTokenSource() *v1.ValueSource { + return &v1.ValueSource{ + Type: "exec", + Exec: &v1.ExecSource{Command: "print-token"}, + } +} + +func invalidFileTokenSource() *v1.ValueSource { + return &v1.ValueSource{ + Type: "file", + } +} + +func unsupportedTokenSource() *v1.ValueSource { + return &v1.ValueSource{Type: "env"} +} + +func validClientConfigWithAuth(auth v1.AuthClientConfig) *v1.ClientCommonConfig { + return &v1.ClientCommonConfig{ + Auth: auth, + Log: v1.LogConfig{ + Level: "info", + }, + Transport: v1.ClientTransportConfig{ + Protocol: "tcp", + WireProtocol: "v1", + }, + } +} + +func validServerConfigWithAuth(auth v1.AuthServerConfig) *v1.ServerConfig { + return &v1.ServerConfig{ + Auth: auth, + Log: v1.LogConfig{ + Level: "info", + }, + } +} diff --git a/pkg/config/v1/validation/client.go b/pkg/config/v1/validation/client.go index 5c8433d5..77004317 100644 --- a/pkg/config/v1/validation/client.go +++ b/pkg/config/v1/validation/client.go @@ -68,22 +68,7 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes)) } - // Validate token/tokenSource mutual exclusivity - if c.Token != "" && c.TokenSource != nil { - errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource")) - } - - // Validate tokenSource if specified - if c.TokenSource != nil { - if c.TokenSource.Type == "exec" { - if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { - errs = AppendError(errs, err) - } - } - if err := c.TokenSource.Validate(); err != nil { - errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err)) - } - } + errs = AppendError(errs, v.validateAuthTokenSource(c.Token, c.TokenSource)) if err := v.validateOIDCConfig(&c.OIDC); err != nil { errs = AppendError(errs, err) diff --git a/pkg/config/v1/validation/server.go b/pkg/config/v1/validation/server.go index 338ecc82..8be740aa 100644 --- a/pkg/config/v1/validation/server.go +++ b/pkg/config/v1/validation/server.go @@ -21,7 +21,6 @@ import ( "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/policy/security" ) func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) { @@ -36,22 +35,7 @@ func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, err errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes)) } - // Validate token/tokenSource mutual exclusivity - if c.Auth.Token != "" && c.Auth.TokenSource != nil { - errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource")) - } - - // Validate tokenSource if specified - if c.Auth.TokenSource != nil { - if c.Auth.TokenSource.Type == "exec" { - if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil { - errs = AppendError(errs, err) - } - } - if err := c.Auth.TokenSource.Validate(); err != nil { - errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err)) - } - } + errs = AppendError(errs, v.validateAuthTokenSource(c.Auth.Token, c.Auth.TokenSource)) if err := validateLogConfig(&c.Log); err != nil { errs = AppendError(errs, err) diff --git a/pkg/metrics/mem/server.go b/pkg/metrics/mem/server.go index add90a5a..a3c2bb65 100644 --- a/pkg/metrics/mem/server.go +++ b/pkg/metrics/mem/server.go @@ -18,6 +18,8 @@ import ( "sync" "time" + "k8s.io/utils/clock" + "github.com/fatedier/frp/pkg/util/log" "github.com/fatedier/frp/pkg/util/metric" server "github.com/fatedier/frp/server/metrics" @@ -37,12 +39,21 @@ func init() { } type serverMetrics struct { - info *ServerStatistics - mu sync.Mutex + info *ServerStatistics + clock clock.WithTicker + mu sync.Mutex } func newServerMetrics() *serverMetrics { + return newServerMetricsWithClock(clock.RealClock{}) +} + +func newServerMetricsWithClock(clk clock.WithTicker) *serverMetrics { + if clk == nil { + clk = clock.RealClock{} + } return &serverMetrics{ + clock: clk, info: &ServerStatistics{ TotalTrafficIn: metric.NewDateCounter(ReserveDays), TotalTrafficOut: metric.NewDateCounter(ReserveDays), @@ -57,14 +68,23 @@ func newServerMetrics() *serverMetrics { } func (m *serverMetrics) run() { - go func() { - for { - time.Sleep(12 * time.Hour) - start := time.Now() + go m.runUntil(nil) +} + +func (m *serverMetrics) runUntil(stopCh <-chan struct{}) { + ticker := m.clock.NewTicker(12 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ticker.C(): + start := m.clock.Now() count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour) - log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start)) + log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, m.clock.Since(start)) + case <-stopCh: + return } - }() + } } func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) { @@ -77,7 +97,7 @@ func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration for name, data := range m.info.ProxyStatistics { if !data.LastCloseTime.IsZero() && data.LastStartTime.Before(data.LastCloseTime) && - time.Since(data.LastCloseTime) > continuousOfflineDuration { + m.clock.Since(data.LastCloseTime) > continuousOfflineDuration { delete(m.info.ProxyStatistics, name) count++ log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String()) @@ -121,7 +141,7 @@ func (m *serverMetrics) NewProxy(name string, proxyType string, user string, cli } proxyStats.User = user proxyStats.ClientID = clientID - proxyStats.LastStartTime = time.Now() + proxyStats.LastStartTime = m.clock.Now() } func (m *serverMetrics) CloseProxy(name string, proxyType string) { @@ -131,7 +151,7 @@ func (m *serverMetrics) CloseProxy(name string, proxyType string) { counter.Dec(1) } if proxyStats, ok := m.info.ProxyStatistics[name]; ok { - proxyStats.LastCloseTime = time.Now() + proxyStats.LastCloseTime = m.clock.Now() } } diff --git a/pkg/metrics/mem/server_test.go b/pkg/metrics/mem/server_test.go new file mode 100644 index 00000000..fe9f9984 --- /dev/null +++ b/pkg/metrics/mem/server_test.go @@ -0,0 +1,83 @@ +package mem + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +func TestServerMetricsUsesClockForProxyTimestamps(t *testing.T) { + require := require.New(t) + + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + metrics := newServerMetricsWithClock(clk) + + metrics.NewProxy("proxy", "tcp", "user", "client-id") + require.Equal(start, metrics.info.ProxyStatistics["proxy"].LastStartTime) + + closedAt := start.Add(time.Minute) + clk.SetTime(closedAt) + metrics.CloseProxy("proxy", "tcp") + require.Equal(closedAt, metrics.info.ProxyStatistics["proxy"].LastCloseTime) +} + +func TestServerMetricsClearUselessInfoUsesClock(t *testing.T) { + require := require.New(t) + + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start.Add(25 * time.Hour)) + metrics := newServerMetricsWithClock(clk) + metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{ + Name: "proxy", + LastStartTime: start.Add(-time.Hour), + LastCloseTime: start, + } + + count, total := metrics.clearUselessInfo(24 * time.Hour) + + require.Equal(1, count) + require.Equal(1, total) + require.Empty(metrics.info.ProxyStatistics) +} + +func TestServerMetricsRunUsesClockTicker(t *testing.T) { + require := require.New(t) + + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + metrics := newServerMetricsWithClock(clk) + metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{ + Name: "proxy", + LastStartTime: start.Add(-time.Hour), + LastCloseTime: start, + } + + stopCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + metrics.runUntil(stopCh) + }() + t.Cleanup(func() { + close(stopCh) + <-done + }) + + require.Eventually(clk.HasWaiters, time.Second, time.Millisecond) + clk.Step(8 * 24 * time.Hour) + + require.Eventually(func() bool { + return !metrics.hasProxyStatistics("proxy") + }, time.Second, time.Millisecond) +} + +func (m *serverMetrics) hasProxyStatistics(name string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + _, ok := m.info.ProxyStatistics[name] + return ok +} diff --git a/pkg/msg/handler.go b/pkg/msg/handler.go index b073e59b..1f2775c9 100644 --- a/pkg/msg/handler.go +++ b/pkg/msg/handler.go @@ -109,10 +109,9 @@ func AsyncHandler(f func(Message)) func(Message) { type Dispatcher struct { rw ReadWriter - sendCh chan Message - doneCh chan struct{} - msgHandlers map[reflect.Type]func(Message) - defaultHandler func(Message) + sendCh chan Message + doneCh chan struct{} + msgHandlers map[reflect.Type]func(Message) } func NewDispatcher(rw ReadWriter) *Dispatcher { @@ -151,8 +150,6 @@ func (d *Dispatcher) readLoop() { if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok { handler(m) - } else if d.defaultHandler != nil { - d.defaultHandler(m) } } } @@ -170,10 +167,6 @@ func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) { d.msgHandlers[reflect.TypeOf(msg)] = handler } -func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) { - d.defaultHandler = handler -} - func (d *Dispatcher) Done() chan struct{} { return d.doneCh } diff --git a/pkg/nathole/analysis.go b/pkg/nathole/analysis.go index ed9eb6ce..2ff28e0e 100644 --- a/pkg/nathole/analysis.go +++ b/pkg/nathole/analysis.go @@ -21,6 +21,7 @@ import ( "time" "github.com/samber/lo" + "k8s.io/utils/clock" ) var ( @@ -144,19 +145,19 @@ func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, Recomman return behaviors[index].A, behaviors[index].B } -func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore { +func getBehaviorScoresByMode(mode int, defaultScore int) []*behaviorScore { return getBehaviorScoresByMode2(mode, defaultScore, defaultScore) } -func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore { +func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*behaviorScore { behaviors := getBehaviorByMode(mode) - scores := make([]*BehaviorScore, 0, len(behaviors)) + scores := make([]*behaviorScore, 0, len(behaviors)) for i := range behaviors { score := receiverScore if behaviors[i].A.Role == DetectRoleSender { score = senderScore } - scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score}) + scores = append(scores, &behaviorScore{Mode: mode, Index: i, Score: score}) } return scores } @@ -170,14 +171,18 @@ type RecommandBehavior struct { ListenRandomPorts int } -type MakeHoleRecords struct { +type makeHoleRecords struct { mu sync.Mutex - scores []*BehaviorScore - LastUpdateTime time.Time + scores []*behaviorScore + clock clock.PassiveClock + lastUpdateTime time.Time } -func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords { - scores := []*BehaviorScore{} +func newMakeHoleRecordsWithClock(c, v *NatFeature, clk clock.PassiveClock) *makeHoleRecords { + if clk == nil { + clk = clock.RealClock{} + } + scores := []*behaviorScore{} easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v}) appendMode0 := func() { switch { @@ -212,13 +217,17 @@ func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords { scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...) scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...) } - return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()} + return &makeHoleRecords{ + scores: scores, + clock: clk, + lastUpdateTime: clk.Now(), + } } -func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) { +func (mhr *makeHoleRecords) reportSuccess(mode int, index int) { mhr.mu.Lock() defer mhr.mu.Unlock() - mhr.LastUpdateTime = time.Now() + mhr.lastUpdateTime = mhr.clock.Now() for i := range mhr.scores { score := mhr.scores[i] if score.Mode != mode || score.Index != index { @@ -231,22 +240,22 @@ func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) { } } -func (mhr *MakeHoleRecords) Recommand() (mode, index int) { +func (mhr *makeHoleRecords) recommand() (mode, index int) { mhr.mu.Lock() defer mhr.mu.Unlock() if len(mhr.scores) == 0 { return 0, 0 } - maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int { + maxScore := slices.MaxFunc(mhr.scores, func(a, b *behaviorScore) int { return cmp.Compare(a.Score, b.Score) }) maxScore.Score-- - mhr.LastUpdateTime = time.Now() + mhr.lastUpdateTime = mhr.clock.Now() return maxScore.Mode, maxScore.Index } -type BehaviorScore struct { +type behaviorScore struct { Mode int Index int // between -10 and 10 @@ -255,16 +264,25 @@ type BehaviorScore struct { type Analyzer struct { // key is client ip + visitor ip - records map[string]*MakeHoleRecords + records map[string]*makeHoleRecords dataReserveDuration time.Duration + clock clock.PassiveClock mu sync.Mutex } func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer { + return newAnalyzerWithClock(dataReserveDuration, clock.RealClock{}) +} + +func newAnalyzerWithClock(dataReserveDuration time.Duration, clk clock.PassiveClock) *Analyzer { + if clk == nil { + clk = clock.RealClock{} + } return &Analyzer{ - records: make(map[string]*MakeHoleRecords), + records: make(map[string]*makeHoleRecords), dataReserveDuration: dataReserveDuration, + clock: clk, } } @@ -272,12 +290,12 @@ func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, in a.mu.Lock() records, ok := a.records[key] if !ok { - records = NewMakeHoleRecords(c, v) + records = newMakeHoleRecordsWithClock(c, v, a.clock) a.records[key] = records } a.mu.Unlock() - mode, index = records.Recommand() + mode, index = records.recommand() cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index) switch mode { @@ -307,11 +325,11 @@ func (a *Analyzer) ReportSuccess(key string, mode, index int) { if !ok { return } - records.ReportSuccess(mode, index) + records.reportSuccess(mode, index) } func (a *Analyzer) Clean() (int, int) { - now := time.Now() + now := a.clock.Now() total := 0 count := 0 @@ -321,7 +339,7 @@ func (a *Analyzer) Clean() (int, int) { total = len(a.records) // clean up records that have not been used for a period of time. for key, records := range a.records { - if now.Sub(records.LastUpdateTime) > a.dataReserveDuration { + if now.Sub(records.lastUpdateTime) > a.dataReserveDuration { delete(a.records, key) count++ } diff --git a/pkg/nathole/analysis_test.go b/pkg/nathole/analysis_test.go new file mode 100644 index 00000000..05d0f8dc --- /dev/null +++ b/pkg/nathole/analysis_test.go @@ -0,0 +1,33 @@ +package nathole + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +func TestAnalyzerUsesClockForRecordTimestamps(t *testing.T) { + require := require.New(t) + + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + analyzer := newAnalyzerWithClock(time.Hour, clk) + clientFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange} + visitorFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange} + + mode, index, _, _ := analyzer.GetRecommandBehaviors("key", clientFeature, visitorFeature) + require.Equal(start, analyzer.records["key"].lastUpdateTime) + + updatedAt := start.Add(time.Minute) + clk.SetTime(updatedAt) + analyzer.ReportSuccess("key", mode, index) + require.Equal(updatedAt, analyzer.records["key"].lastUpdateTime) + + clk.SetTime(start.Add(2 * time.Hour)) + count, total := analyzer.Clean() + require.Equal(1, count) + require.Equal(1, total) + require.Empty(analyzer.records) +} diff --git a/pkg/nathole/controller.go b/pkg/nathole/controller.go index 2562bfd2..22a5ec53 100644 --- a/pkg/nathole/controller.go +++ b/pkg/nathole/controller.go @@ -326,40 +326,16 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR } protocol := vm.Protocol - vResp := &msg.NatHoleResp{ - TransactionID: vm.TransactionID, - Sid: session.sid, - Protocol: protocol, - CandidateAddrs: slices.Compact(cm.MappedAddrs), - AssistedAddrs: slices.Compact(cm.AssistedAddrs), - DetectBehavior: msg.NatHoleDetectBehavior{ - Mode: mode, - Role: vBehavior.Role, - TTL: vBehavior.TTL, - SendDelayMs: vBehavior.SendDelayMs, - ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs, - SendRandomPorts: vBehavior.PortsRandomNumber, - ListenRandomPorts: vBehavior.ListenRandomPorts, - CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber), - }, - } - cResp := &msg.NatHoleResp{ - TransactionID: cm.TransactionID, - Sid: session.sid, - Protocol: protocol, - CandidateAddrs: slices.Compact(vm.MappedAddrs), - AssistedAddrs: slices.Compact(vm.AssistedAddrs), - DetectBehavior: msg.NatHoleDetectBehavior{ - Mode: mode, - Role: cBehavior.Role, - TTL: cBehavior.TTL, - SendDelayMs: cBehavior.SendDelayMs, - ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs, - SendRandomPorts: cBehavior.PortsRandomNumber, - ListenRandomPorts: cBehavior.ListenRandomPorts, - CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber), - }, - } + vResp := newNatHoleResponse( + vm.TransactionID, session.sid, protocol, mode, + cm.MappedAddrs, cm.AssistedAddrs, vBehavior, + timeoutMs-vBehavior.SendDelayMs, cNatFeature.PortsDifference, + ) + cResp := newNatHoleResponse( + cm.TransactionID, session.sid, protocol, mode, + vm.MappedAddrs, vm.AssistedAddrs, cBehavior, + timeoutMs-cBehavior.SendDelayMs, vNatFeature.PortsDifference, + ) log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s", session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol) @@ -368,6 +344,38 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR return vResp, cResp, nil } +func newNatHoleResponse( + transactionID string, + sid string, + protocol string, + mode int, + candidateAddrs []string, + assistedAddrs []string, + behavior RecommandBehavior, + readTimeoutMs int, + portsDifference int, +) *msg.NatHoleResp { + compactCandidateAddrs := slices.Compact(candidateAddrs) + compactAssistedAddrs := slices.Compact(assistedAddrs) + return &msg.NatHoleResp{ + TransactionID: transactionID, + Sid: sid, + Protocol: protocol, + CandidateAddrs: compactCandidateAddrs, + AssistedAddrs: compactAssistedAddrs, + DetectBehavior: msg.NatHoleDetectBehavior{ + Mode: mode, + Role: behavior.Role, + TTL: behavior.TTL, + SendDelayMs: behavior.SendDelayMs, + ReadTimeoutMs: readTimeoutMs, + SendRandomPorts: behavior.PortsRandomNumber, + ListenRandomPorts: behavior.ListenRandomPorts, + CandidatePorts: getRangePorts(candidateAddrs, portsDifference, behavior.PortsRangeNumber), + }, + } +} + func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange { if maxNumber <= 0 { return nil diff --git a/pkg/plugin/client/http2http.go b/pkg/plugin/client/http2http.go index e50a91c0..f92daddd 100644 --- a/pkg/plugin/client/http2http.go +++ b/pkg/plugin/client/http2http.go @@ -17,17 +17,9 @@ package client import ( - "context" - stdlog "log" - "net/http" "net/http/httputil" - "time" - - "github.com/fatedier/golib/pool" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/util/log" - netpkg "github.com/fatedier/frp/pkg/util/net" ) func init() { @@ -37,57 +29,28 @@ func init() { type HTTP2HTTPPlugin struct { opts *v1.HTTP2HTTPPluginOptions - l *Listener - s *http.Server + *httpBridgePlugin } func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) { opts := options.(*v1.HTTP2HTTPPluginOptions) - listener := NewProxyListener() - p := &HTTP2HTTPPlugin{ opts: opts, - l: listener, } - rp := &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { + rp := newHTTPBridgeReverseProxy( + func(r *httputil.ProxyRequest) { req := r.Out - req.URL.Scheme = "http" - req.URL.Host = p.opts.LocalAddr - if p.opts.HostHeaderRewrite != "" { - req.Host = p.opts.HostHeaderRewrite - } - for k, v := range p.opts.RequestHeaders.Set { - req.Header.Set(k, v) - } + rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders) }, - BufferPool: pool.NewBuffer(32 * 1024), - ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), - } - - p.s = &http.Server{ - Handler: rp, - ReadHeaderTimeout: 60 * time.Second, - } - - go func() { - _ = p.s.Serve(listener) - }() + nil, + ) + p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false) return p, nil } -func (p *HTTP2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) { - wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn) - _ = p.l.PutConn(wrapConn) -} - func (p *HTTP2HTTPPlugin) Name() string { return v1.PluginHTTP2HTTP } - -func (p *HTTP2HTTPPlugin) Close() error { - return p.s.Close() -} diff --git a/pkg/plugin/client/http2https.go b/pkg/plugin/client/http2https.go index 8119e095..f5a61602 100644 --- a/pkg/plugin/client/http2https.go +++ b/pkg/plugin/client/http2https.go @@ -17,18 +17,11 @@ package client import ( - "context" "crypto/tls" - stdlog "log" "net/http" "net/http/httputil" - "time" - - "github.com/fatedier/golib/pool" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/util/log" - netpkg "github.com/fatedier/frp/pkg/util/net" ) func init() { @@ -38,65 +31,35 @@ func init() { type HTTP2HTTPSPlugin struct { opts *v1.HTTP2HTTPSPluginOptions - l *Listener - s *http.Server + *httpBridgePlugin } func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) { opts := options.(*v1.HTTP2HTTPSPluginOptions) - listener := NewProxyListener() - p := &HTTP2HTTPSPlugin{ opts: opts, - l: listener, } tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - rp := &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { + rp := newHTTPBridgeReverseProxy( + func(r *httputil.ProxyRequest) { r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"] r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"] req := r.Out - req.URL.Scheme = "https" - req.URL.Host = p.opts.LocalAddr - if p.opts.HostHeaderRewrite != "" { - req.Host = p.opts.HostHeaderRewrite - } - for k, v := range p.opts.RequestHeaders.Set { - req.Header.Set(k, v) - } + rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders) }, - Transport: tr, - BufferPool: pool.NewBuffer(32 * 1024), - ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), - } - - p.s = &http.Server{ - Handler: rp, - ReadHeaderTimeout: 60 * time.Second, - } - - go func() { - _ = p.s.Serve(listener) - }() + tr, + ) + p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false) return p, nil } -func (p *HTTP2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) { - wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn) - _ = p.l.PutConn(wrapConn) -} - func (p *HTTP2HTTPSPlugin) Name() string { return v1.PluginHTTP2HTTPS } - -func (p *HTTP2HTTPSPlugin) Close() error { - return p.s.Close() -} diff --git a/pkg/plugin/client/http_common.go b/pkg/plugin/client/http_common.go new file mode 100644 index 00000000..5010a4e9 --- /dev/null +++ b/pkg/plugin/client/http_common.go @@ -0,0 +1,126 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !frps + +package client + +import ( + "context" + stdlog "log" + "net/http" + "net/http/httputil" + "time" + + "github.com/fatedier/golib/pool" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/plugin/client/internal/httpsserver" + "github.com/fatedier/frp/pkg/util/log" + netpkg "github.com/fatedier/frp/pkg/util/net" +) + +const httpBridgeReadHeaderTimeout = 60 * time.Second + +func rewriteHTTPPluginRequest( + req *http.Request, + scheme string, + localAddr string, + hostHeaderRewrite string, + requestHeaders v1.HeaderOperations, +) { + req.URL.Scheme = scheme + req.URL.Host = localAddr + if hostHeaderRewrite != "" { + req.Host = hostHeaderRewrite + } + for k, v := range requestHeaders.Set { + req.Header.Set(k, v) + } +} + +type httpBridgePlugin struct { + l *Listener + s *http.Server + + useSourceRemoteAddr bool +} + +func newHTTPBridgePluginServer(handler http.Handler, useSourceRemoteAddr bool) *httpBridgePlugin { + listener := NewProxyListener() + p := &httpBridgePlugin{ + l: listener, + useSourceRemoteAddr: useSourceRemoteAddr, + } + p.s = &http.Server{ + Handler: handler, + ReadHeaderTimeout: httpBridgeReadHeaderTimeout, + } + go func() { + _ = p.s.Serve(listener) + }() + return p +} + +func newHTTPSBridgePluginServer( + handler http.Handler, + crtPath string, + keyPath string, + enableHTTP2 *bool, + useSourceRemoteAddr bool, +) (*httpBridgePlugin, error) { + listener := NewProxyListener() + server, err := httpsserver.New(handler, crtPath, keyPath, enableHTTP2) + if err != nil { + return nil, err + } + p := &httpBridgePlugin{ + l: listener, + s: server, + useSourceRemoteAddr: useSourceRemoteAddr, + } + go func() { + _ = p.s.ServeTLS(listener, "", "") + }() + return p, nil +} + +func newHTTPBridgeReverseProxy( + rewrite func(*httputil.ProxyRequest), + transport http.RoundTripper, +) *httputil.ReverseProxy { + rp := &httputil.ReverseProxy{ + Rewrite: rewrite, + BufferPool: pool.NewBuffer(32 * 1024), + ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), + } + if transport != nil { + rp.Transport = transport + } + return rp +} + +func (p *httpBridgePlugin) Handle(_ context.Context, connInfo *ConnectionInfo) { + wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn) + if p.useSourceRemoteAddr && connInfo.SrcAddr != nil { + wrapConn.SetRemoteAddr(connInfo.SrcAddr) + } + _ = p.l.PutConn(wrapConn) +} + +func (p *httpBridgePlugin) Close() error { + err := p.s.Close() + _ = p.l.Close() + return err +} diff --git a/pkg/plugin/client/https2http.go b/pkg/plugin/client/https2http.go index 963b9d2e..8722bec9 100644 --- a/pkg/plugin/client/https2http.go +++ b/pkg/plugin/client/https2http.go @@ -17,22 +17,9 @@ package client import ( - "context" - "crypto/tls" - "fmt" - stdlog "log" - "net/http" "net/http/httputil" - "time" - - "github.com/fatedier/golib/pool" - "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/transport" - httppkg "github.com/fatedier/frp/pkg/util/http" - "github.com/fatedier/frp/pkg/util/log" - netpkg "github.com/fatedier/frp/pkg/util/net" ) func init() { @@ -42,80 +29,35 @@ func init() { type HTTPS2HTTPPlugin struct { opts *v1.HTTPS2HTTPPluginOptions - l *Listener - s *http.Server + *httpBridgePlugin } func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) { opts := options.(*v1.HTTPS2HTTPPluginOptions) - listener := NewProxyListener() p := &HTTPS2HTTPPlugin{ opts: opts, - l: listener, } - rp := &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { + rp := newHTTPBridgeReverseProxy( + func(r *httputil.ProxyRequest) { r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] r.SetXForwarded() req := r.Out - req.URL.Scheme = "http" - req.URL.Host = p.opts.LocalAddr - if p.opts.HostHeaderRewrite != "" { - req.Host = p.opts.HostHeaderRewrite - } - for k, v := range p.opts.RequestHeaders.Set { - req.Header.Set(k, v) - } + rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders) }, - BufferPool: pool.NewBuffer(32 * 1024), - ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), - } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.TLS != nil { - tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName) - host, _ := httppkg.CanonicalHost(r.Host) - if tlsServerName != "" && tlsServerName != host { - w.WriteHeader(http.StatusMisdirectedRequest) - return - } - } - rp.ServeHTTP(w, r) - }) + nil, + ) - tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "") + server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true) if err != nil { - return nil, fmt.Errorf("gen TLS config error: %v", err) + return nil, err } + p.httpBridgePlugin = server - p.s = &http.Server{ - Handler: handler, - ReadHeaderTimeout: 60 * time.Second, - TLSConfig: tlsConfig, - } - if !lo.FromPtr(opts.EnableHTTP2) { - p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - } - - go func() { - _ = p.s.ServeTLS(listener, "", "") - }() return p, nil } -func (p *HTTPS2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) { - wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn) - if connInfo.SrcAddr != nil { - wrapConn.SetRemoteAddr(connInfo.SrcAddr) - } - _ = p.l.PutConn(wrapConn) -} - func (p *HTTPS2HTTPPlugin) Name() string { return v1.PluginHTTPS2HTTP } - -func (p *HTTPS2HTTPPlugin) Close() error { - return p.s.Close() -} diff --git a/pkg/plugin/client/https2https.go b/pkg/plugin/client/https2https.go index 5c669d36..7b9e455e 100644 --- a/pkg/plugin/client/https2https.go +++ b/pkg/plugin/client/https2https.go @@ -17,22 +17,11 @@ package client import ( - "context" "crypto/tls" - "fmt" - stdlog "log" "net/http" "net/http/httputil" - "time" - - "github.com/fatedier/golib/pool" - "github.com/samber/lo" v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/transport" - httppkg "github.com/fatedier/frp/pkg/util/http" - "github.com/fatedier/frp/pkg/util/log" - netpkg "github.com/fatedier/frp/pkg/util/net" ) func init() { @@ -42,86 +31,39 @@ func init() { type HTTPS2HTTPSPlugin struct { opts *v1.HTTPS2HTTPSPluginOptions - l *Listener - s *http.Server + *httpBridgePlugin } func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) { opts := options.(*v1.HTTPS2HTTPSPluginOptions) - listener := NewProxyListener() - p := &HTTPS2HTTPSPlugin{ opts: opts, - l: listener, } tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - rp := &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { + rp := newHTTPBridgeReverseProxy( + func(r *httputil.ProxyRequest) { r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] r.SetXForwarded() req := r.Out - req.URL.Scheme = "https" - req.URL.Host = p.opts.LocalAddr - if p.opts.HostHeaderRewrite != "" { - req.Host = p.opts.HostHeaderRewrite - } - for k, v := range p.opts.RequestHeaders.Set { - req.Header.Set(k, v) - } + rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders) }, - Transport: tr, - BufferPool: pool.NewBuffer(32 * 1024), - ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0), - } - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.TLS != nil { - tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName) - host, _ := httppkg.CanonicalHost(r.Host) - if tlsServerName != "" && tlsServerName != host { - w.WriteHeader(http.StatusMisdirectedRequest) - return - } - } - rp.ServeHTTP(w, r) - }) + tr, + ) - tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "") + server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true) if err != nil { - return nil, fmt.Errorf("gen TLS config error: %v", err) + return nil, err } + p.httpBridgePlugin = server - p.s = &http.Server{ - Handler: handler, - ReadHeaderTimeout: 60 * time.Second, - TLSConfig: tlsConfig, - } - if !lo.FromPtr(opts.EnableHTTP2) { - p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - } - - go func() { - _ = p.s.ServeTLS(listener, "", "") - }() return p, nil } -func (p *HTTPS2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) { - wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn) - if connInfo.SrcAddr != nil { - wrapConn.SetRemoteAddr(connInfo.SrcAddr) - } - _ = p.l.PutConn(wrapConn) -} - func (p *HTTPS2HTTPSPlugin) Name() string { return v1.PluginHTTPS2HTTPS } - -func (p *HTTPS2HTTPSPlugin) Close() error { - return p.s.Close() -} diff --git a/pkg/plugin/client/internal/httpsserver/server.go b/pkg/plugin/client/internal/httpsserver/server.go new file mode 100644 index 00000000..a2047b90 --- /dev/null +++ b/pkg/plugin/client/internal/httpsserver/server.go @@ -0,0 +1,60 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !frps + +package httpsserver + +import ( + "crypto/tls" + "fmt" + "net/http" + "time" + + "github.com/samber/lo" + + "github.com/fatedier/frp/pkg/transport" + httppkg "github.com/fatedier/frp/pkg/util/http" +) + +func New(handler http.Handler, crtPath, keyPath string, enableHTTP2 *bool) (*http.Server, error) { + tlsConfig, err := transport.NewServerTLSConfig(crtPath, keyPath, "") + if err != nil { + return nil, fmt.Errorf("gen TLS config error: %v", err) + } + + server := &http.Server{ + Handler: withMisdirectedRequestCheck(handler), + ReadHeaderTimeout: 60 * time.Second, + TLSConfig: tlsConfig, + } + if !lo.FromPtr(enableHTTP2) { + server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + } + return server, nil +} + +func withMisdirectedRequestCheck(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS != nil { + tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName) + host, _ := httppkg.CanonicalHost(r.Host) + if tlsServerName != "" && tlsServerName != host { + w.WriteHeader(http.StatusMisdirectedRequest) + return + } + } + handler.ServeHTTP(w, r) + }) +} diff --git a/pkg/plugin/server/manager.go b/pkg/plugin/server/manager.go index dabfb46c..6f1061c6 100644 --- a/pkg/plugin/server/manager.go +++ b/pkg/plugin/server/manager.go @@ -44,6 +44,67 @@ func NewManager() *Manager { } } +func newPluginRequestContext() (context.Context, *xlog.Logger) { + reqid, _ := util.RandID() + xl := xlog.New().AppendPrefix("reqid: " + reqid) + ctx := xlog.NewContext(context.Background(), xl) + return NewReqidContext(ctx, reqid), xl +} + +type pluginErrorLogMode bool + +const ( + // Warn is the zero value because it is the default for mutable plugin operations. + pluginErrorLogWarn pluginErrorLogMode = false + pluginErrorLogInfo pluginErrorLogMode = true +) + +func logPluginError(xl *xlog.Logger, p Plugin, op string, err error, mode pluginErrorLogMode) { + if mode == pluginErrorLogInfo { + xl.Infof("send %s request to plugin [%s] error: %v", op, p.Name(), err) + return + } + xl.Warnf("send %s request to plugin [%s] error: %v", op, p.Name(), err) +} + +func handleMutableContent[T any]( + plugins []Plugin, + op string, + content *T, + logMode pluginErrorLogMode, +) (*T, error) { + if len(plugins) == 0 { + return content, nil + } + + var ( + res = &Response{ + Reject: false, + Unchange: true, + } + retContent any + err error + ) + ctx, xl := newPluginRequestContext() + + for _, p := range plugins { + res, retContent, err = p.Handle(ctx, op, *content) + if err != nil { + logPluginError(xl, p, op, err, logMode) + return nil, errors.New("send " + op + " request to plugin error") + } + if res.Reject { + return nil, fmt.Errorf("%s", res.RejectReason) + } + if !res.Unchange { + // Preserve the existing Plugin contract: changed content must be *T. + // Buggy Plugin implementations still panic here, by design. + content = retContent.(*T) + } + } + return content, nil +} + func (m *Manager) Register(p Plugin) { if p.IsSupport(OpLogin) { m.loginPlugins = append(m.loginPlugins, p) @@ -66,71 +127,11 @@ func (m *Manager) Register(p Plugin) { } func (m *Manager) Login(content *LoginContent) (*LoginContent, error) { - if len(m.loginPlugins) == 0 { - return content, nil - } - - var ( - res = &Response{ - Reject: false, - Unchange: true, - } - retContent any - err error - ) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) - - for _, p := range m.loginPlugins { - res, retContent, err = p.Handle(ctx, OpLogin, *content) - if err != nil { - xl.Warnf("send Login request to plugin [%s] error: %v", p.Name(), err) - return nil, errors.New("send Login request to plugin error") - } - if res.Reject { - return nil, fmt.Errorf("%s", res.RejectReason) - } - if !res.Unchange { - content = retContent.(*LoginContent) - } - } - return content, nil + return handleMutableContent(m.loginPlugins, OpLogin, content, pluginErrorLogWarn) } func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) { - if len(m.newProxyPlugins) == 0 { - return content, nil - } - - var ( - res = &Response{ - Reject: false, - Unchange: true, - } - retContent any - err error - ) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) - - for _, p := range m.newProxyPlugins { - res, retContent, err = p.Handle(ctx, OpNewProxy, *content) - if err != nil { - xl.Warnf("send NewProxy request to plugin [%s] error: %v", p.Name(), err) - return nil, errors.New("send NewProxy request to plugin error") - } - if res.Reject { - return nil, fmt.Errorf("%s", res.RejectReason) - } - if !res.Unchange { - content = retContent.(*NewProxyContent) - } - } - return content, nil + return handleMutableContent(m.newProxyPlugins, OpNewProxy, content, pluginErrorLogWarn) } func (m *Manager) CloseProxy(content *CloseProxyContent) error { @@ -139,10 +140,7 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error { } errs := make([]string, 0) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) + ctx, xl := newPluginRequestContext() for _, p := range m.closeProxyPlugins { _, _, err := p.Handle(ctx, OpCloseProxy, *content) @@ -159,103 +157,14 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error { } func (m *Manager) Ping(content *PingContent) (*PingContent, error) { - if len(m.pingPlugins) == 0 { - return content, nil - } - - var ( - res = &Response{ - Reject: false, - Unchange: true, - } - retContent any - err error - ) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) - - for _, p := range m.pingPlugins { - res, retContent, err = p.Handle(ctx, OpPing, *content) - if err != nil { - xl.Warnf("send Ping request to plugin [%s] error: %v", p.Name(), err) - return nil, errors.New("send Ping request to plugin error") - } - if res.Reject { - return nil, fmt.Errorf("%s", res.RejectReason) - } - if !res.Unchange { - content = retContent.(*PingContent) - } - } - return content, nil + return handleMutableContent(m.pingPlugins, OpPing, content, pluginErrorLogWarn) } func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) { - if len(m.newWorkConnPlugins) == 0 { - return content, nil - } - - var ( - res = &Response{ - Reject: false, - Unchange: true, - } - retContent any - err error - ) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) - - for _, p := range m.newWorkConnPlugins { - res, retContent, err = p.Handle(ctx, OpNewWorkConn, *content) - if err != nil { - xl.Warnf("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err) - return nil, errors.New("send NewWorkConn request to plugin error") - } - if res.Reject { - return nil, fmt.Errorf("%s", res.RejectReason) - } - if !res.Unchange { - content = retContent.(*NewWorkConnContent) - } - } - return content, nil + return handleMutableContent(m.newWorkConnPlugins, OpNewWorkConn, content, pluginErrorLogWarn) } func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) { - if len(m.newUserConnPlugins) == 0 { - return content, nil - } - - var ( - res = &Response{ - Reject: false, - Unchange: true, - } - retContent any - err error - ) - reqid, _ := util.RandID() - xl := xlog.New().AppendPrefix("reqid: " + reqid) - ctx := xlog.NewContext(context.Background(), xl) - ctx = NewReqidContext(ctx, reqid) - - for _, p := range m.newUserConnPlugins { - res, retContent, err = p.Handle(ctx, OpNewUserConn, *content) - if err != nil { - xl.Infof("send NewUserConn request to plugin [%s] error: %v", p.Name(), err) - return nil, errors.New("send NewUserConn request to plugin error") - } - if res.Reject { - return nil, fmt.Errorf("%s", res.RejectReason) - } - if !res.Unchange { - content = retContent.(*NewUserConnContent) - } - } - return content, nil + // Preserve the pre-refactor log level for NewUserConn plugin errors. + return handleMutableContent(m.newUserConnPlugins, OpNewUserConn, content, pluginErrorLogInfo) } diff --git a/pkg/plugin/server/manager_test.go b/pkg/plugin/server/manager_test.go new file mode 100644 index 00000000..92391d4b --- /dev/null +++ b/pkg/plugin/server/manager_test.go @@ -0,0 +1,336 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + goliblog "github.com/fatedier/golib/log" + + "github.com/fatedier/frp/pkg/msg" + frplog "github.com/fatedier/frp/pkg/util/log" +) + +type testPlugin struct { + name string + ops map[string]bool + handler func(context.Context, string, any) (*Response, any, error) +} + +// Log-capturing subtests serialize global logger swaps; do not use t.Parallel. +var logCaptureMu sync.Mutex + +type logCapture struct { + bytes.Buffer + levels []goliblog.Level +} + +func (p testPlugin) Name() string { + return p.name +} + +func (p testPlugin) IsSupport(op string) bool { + return p.ops[op] +} + +func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) { + return p.handler(ctx, op, content) +} + +func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) { + w.levels = append(w.levels, level) + return w.Write(p) +} + +func captureLogOutput(t *testing.T) *logCapture { + t.Helper() + + logCaptureMu.Lock() + logOutput := &logCapture{} + oldLogger := frplog.Logger + frplog.Logger = goliblog.New( + goliblog.WithOutput(logOutput), + goliblog.WithLevel(goliblog.TraceLevel), + goliblog.WithCaller(false), + ) + t.Cleanup(func() { + frplog.Logger = oldLogger + logCaptureMu.Unlock() + }) + return logOutput +} + +var mutablePluginOps = []struct { + name string + op string +}{ + {name: "login", op: OpLogin}, + {name: "new proxy", op: OpNewProxy}, + {name: "ping", op: OpPing}, + {name: "new work conn", op: OpNewWorkConn}, + {name: "new user conn", op: OpNewUserConn}, +} + +func callMutableWithUser(m *Manager, op string, user string) (string, error) { + switch op { + case OpLogin: + got, err := m.Login(&LoginContent{Login: msg.Login{User: user}}) + if got == nil { + return "", err + } + return got.User, err + case OpNewProxy: + got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}}) + if got == nil { + return "", err + } + return got.User.User, err + case OpPing: + got, err := m.Ping(&PingContent{User: UserInfo{User: user}}) + if got == nil { + return "", err + } + return got.User.User, err + case OpNewWorkConn: + got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}}) + if got == nil { + return "", err + } + return got.User.User, err + case OpNewUserConn: + got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}}) + if got == nil { + return "", err + } + return got.User.User, err + default: + panic("unsupported mutable op: " + op) + } +} + +func mutableUser(t *testing.T, op string, content any) string { + t.Helper() + + switch op { + case OpLogin: + return content.(LoginContent).User + case OpNewProxy: + return content.(NewProxyContent).User.User + case OpPing: + return content.(PingContent).User.User + case OpNewWorkConn: + return content.(NewWorkConnContent).User.User + case OpNewUserConn: + return content.(NewUserConnContent).User.User + default: + t.Fatalf("unsupported mutable op: %s", op) + return "" + } +} + +func mutateMutableContent(t *testing.T, op string, content any, user string) any { + t.Helper() + + switch op { + case OpLogin: + got := content.(LoginContent) + got.User = user + return &got + case OpNewProxy: + got := content.(NewProxyContent) + got.User.User = user + return &got + case OpPing: + got := content.(PingContent) + got.User.User = user + return &got + case OpNewWorkConn: + got := content.(NewWorkConnContent) + got.User.User = user + return &got + case OpNewUserConn: + got := content.(NewUserConnContent) + got.User.User = user + return &got + default: + t.Fatalf("unsupported mutable op: %s", op) + return nil + } +} + +func TestManagerMutableContentAcrossOps(t *testing.T) { + for _, tt := range mutablePluginOps { + t.Run(tt.name, func(t *testing.T) { + m := NewManager() + m.Register(testPlugin{ + name: "mutate", + ops: map[string]bool{tt.op: true}, + handler: func(ctx context.Context, op string, content any) (*Response, any, error) { + if op != tt.op { + t.Fatalf("unexpected op: %s", op) + } + if GetReqidFromContext(ctx) == "" { + t.Fatal("expected request id in context") + } + if got := mutableUser(t, tt.op, content); got != "initial" { + t.Fatalf("expected initial user, got %q", got) + } + return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil + }, + }) + m.Register(testPlugin{ + name: "observe", + ops: map[string]bool{tt.op: true}, + handler: func(ctx context.Context, op string, content any) (*Response, any, error) { + if op != tt.op { + t.Fatalf("unexpected op: %s", op) + } + if GetReqidFromContext(ctx) == "" { + t.Fatal("expected request id in context") + } + if got := mutableUser(t, tt.op, content); got != "mutated" { + t.Fatalf("expected mutated user, got %q", got) + } + return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil + }, + }) + + got, err := callMutableWithUser(m, tt.op, "initial") + if err != nil { + t.Fatalf("mutable op failed: %v", err) + } + if got != "mutated" { + t.Fatalf("expected mutated user, got %q", got) + } + }) + } +} + +func TestManagerMutableContentRejectStopsChain(t *testing.T) { + m := NewManager() + + var called bool + m.Register(testPlugin{ + name: "reject", + ops: map[string]bool{OpPing: true}, + handler: func(context.Context, string, any) (*Response, any, error) { + return &Response{Reject: true, RejectReason: "blocked"}, nil, nil + }, + }) + m.Register(testPlugin{ + name: "unused", + ops: map[string]bool{OpPing: true}, + handler: func(context.Context, string, any) (*Response, any, error) { + called = true + return &Response{Unchange: true}, nil, nil + }, + }) + + got, err := m.Ping(&PingContent{}) + if err == nil { + t.Fatal("expected reject error") + } + if got != nil { + t.Fatalf("expected no returned content, got %#v", got) + } + if err.Error() != "blocked" { + t.Fatalf("unexpected error: %v", err) + } + if called { + t.Fatal("expected plugin chain to stop after reject") + } +} + +func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) { + tests := []struct { + name string + op string + level goliblog.Level + }{ + {name: "default warning", op: OpLogin, level: goliblog.WarnLevel}, + {name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOutput := captureLogOutput(t) + m := NewManager() + m.Register(testPlugin{ + name: "error", + ops: map[string]bool{tt.op: true}, + handler: func(context.Context, string, any) (*Response, any, error) { + return nil, nil, errors.New("boom") + }, + }) + + _, err := callMutableWithUser(m, tt.op, "initial") + if err == nil { + t.Fatal("expected plugin error") + } + if want := "send " + tt.op + " request to plugin error"; err.Error() != want { + t.Fatalf("unexpected error: %v", err) + } + if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level { + t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String()) + } + }) + } +} + +func TestManagerCloseProxyAggregatesErrors(t *testing.T) { + logOutput := captureLogOutput(t) + m := NewManager() + + for _, name := range []string{"first", "second"} { + m.Register(testPlugin{ + name: name, + ops: map[string]bool{OpCloseProxy: true}, + handler: func(ctx context.Context, op string, content any) (*Response, any, error) { + if GetReqidFromContext(ctx) == "" { + t.Fatal("expected request id in context") + } + if op != OpCloseProxy { + t.Fatalf("unexpected op: %s", op) + } + return nil, nil, errors.New(name + " error") + }, + }) + } + + err := m.CloseProxy(&CloseProxyContent{}) + if err == nil { + t.Fatal("expected close proxy error") + } + if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") { + t.Fatalf("unexpected close proxy error prefix: %v", err) + } + if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") { + t.Fatalf("missing aggregated errors: %v", err) + } + if len(logOutput.levels) != 2 { + t.Fatalf("expected two warning logs, got %v", logOutput.levels) + } + for _, level := range logOutput.levels { + if level != goliblog.WarnLevel { + t.Fatalf("expected warning log level, got %v", logOutput.levels) + } + } +} diff --git a/pkg/util/metric/date_counter.go b/pkg/util/metric/date_counter.go index 4524fec8..5e6b889f 100644 --- a/pkg/util/metric/date_counter.go +++ b/pkg/util/metric/date_counter.go @@ -17,6 +17,8 @@ package metric import ( "sync" "time" + + "k8s.io/utils/clock" ) type DateCounter interface { @@ -38,27 +40,33 @@ func NewDateCounter(reserveDays int64) DateCounter { type StandardDateCounter struct { reserveDays int64 counts []int64 + clock clock.PassiveClock lastUpdateDate time.Time mu sync.Mutex } func newStandardDateCounter(reserveDays int64) *StandardDateCounter { - now := time.Now() - now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) - s := &StandardDateCounter{ + return newStandardDateCounterWithClock(reserveDays, clock.RealClock{}) +} + +func newStandardDateCounterWithClock(reserveDays int64, clk clock.PassiveClock) *StandardDateCounter { + if clk == nil { + clk = clock.RealClock{} + } + return &StandardDateCounter{ reserveDays: reserveDays, counts: make([]int64, reserveDays), - lastUpdateDate: now, + clock: clk, + lastUpdateDate: startOfDay(clk.Now()), } - return s } func (c *StandardDateCounter) TodayCount() int64 { c.mu.Lock() defer c.mu.Unlock() - c.rotate(time.Now()) + c.rotate(c.clock.Now()) return c.counts[0] } @@ -70,65 +78,61 @@ func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 { c.mu.Lock() defer c.mu.Unlock() - c.rotate(time.Now()) - for i := 0; i < int(lastdays); i++ { - counts[i] = c.counts[i] - } + c.rotate(c.clock.Now()) + copy(counts, c.counts) return counts } func (c *StandardDateCounter) Inc(count int64) { c.mu.Lock() defer c.mu.Unlock() - c.rotate(time.Now()) + c.rotate(c.clock.Now()) c.counts[0] += count } func (c *StandardDateCounter) Dec(count int64) { c.mu.Lock() defer c.mu.Unlock() - c.rotate(time.Now()) + c.rotate(c.clock.Now()) c.counts[0] -= count } func (c *StandardDateCounter) Snapshot() DateCounter { c.mu.Lock() defer c.mu.Unlock() - tmp := newStandardDateCounter(c.reserveDays) - for i := 0; i < int(c.reserveDays); i++ { - tmp.counts[i] = c.counts[i] - } + tmp := newStandardDateCounterWithClock(c.reserveDays, c.clock) + copy(tmp.counts, c.counts) return tmp } func (c *StandardDateCounter) Clear() { c.mu.Lock() defer c.mu.Unlock() - for i := 0; i < int(c.reserveDays); i++ { - c.counts[i] = 0 - } + clear(c.counts) } // rotate // Must hold the lock before calling this function. func (c *StandardDateCounter) rotate(now time.Time) { - now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + now = startOfDay(now) days := int(now.Sub(c.lastUpdateDate).Hours() / 24) - - defer func() { - c.lastUpdateDate = now - }() + reserveDays := int(c.reserveDays) if days <= 0 { return - } else if days >= int(c.reserveDays) { + } else if days >= reserveDays { c.counts = make([]int64, c.reserveDays) + c.lastUpdateDate = now return } newCounts := make([]int64, c.reserveDays) - for i := days; i < int(c.reserveDays); i++ { - newCounts[i] = c.counts[i-days] - } + copy(newCounts[days:], c.counts[:reserveDays-days]) c.counts = newCounts + c.lastUpdateDate = now +} + +// startOfDay returns midnight in t's location. +func startOfDay(t time.Time) time.Time { + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) } diff --git a/pkg/util/metric/date_counter_test.go b/pkg/util/metric/date_counter_test.go index 8752f198..cbd29386 100644 --- a/pkg/util/metric/date_counter_test.go +++ b/pkg/util/metric/date_counter_test.go @@ -1,9 +1,12 @@ package metric import ( + "sync" "testing" + "time" "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" ) func TestDateCounter(t *testing.T) { @@ -25,3 +28,107 @@ func TestDateCounter(t *testing.T) { dcTmp := dc.Snapshot() require.EqualValues(5, dcTmp.TodayCount()) } + +func TestDateCounterRotate(t *testing.T) { + loc := time.FixedZone("test", 8*60*60) + lastUpdateDate := time.Date(2026, time.May, 8, 0, 0, 0, 0, loc) + + tests := []struct { + name string + now time.Time + want []int64 + wantLastUpdateDate time.Time + }{ + { + name: "same day", + now: time.Date(2026, time.May, 8, 12, 30, 0, 0, loc), + want: []int64{10, 7, 3}, + wantLastUpdateDate: lastUpdateDate, + }, + { + name: "clock skew", + now: time.Date(2026, time.May, 7, 12, 30, 0, 0, loc), + want: []int64{10, 7, 3}, + wantLastUpdateDate: lastUpdateDate, + }, + { + name: "one day", + now: time.Date(2026, time.May, 9, 12, 30, 0, 0, loc), + want: []int64{0, 10, 7}, + wantLastUpdateDate: time.Date(2026, time.May, 9, 0, 0, 0, 0, loc), + }, + { + name: "two days", + now: time.Date(2026, time.May, 10, 12, 30, 0, 0, loc), + want: []int64{0, 0, 10}, + wantLastUpdateDate: time.Date(2026, time.May, 10, 0, 0, 0, 0, loc), + }, + { + name: "all reserved days elapsed", + now: time.Date(2026, time.May, 11, 12, 30, 0, 0, loc), + want: []int64{0, 0, 0}, + wantLastUpdateDate: time.Date(2026, time.May, 11, 0, 0, 0, 0, loc), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + dc := newStandardDateCounter(3) + dc.counts = []int64{10, 7, 3} + dc.lastUpdateDate = lastUpdateDate + + dc.mu.Lock() + dc.rotate(tt.now) + dc.mu.Unlock() + + require.Equal(tt.want, dc.counts) + require.Equal(tt.wantLastUpdateDate, dc.lastUpdateDate) + }) + } +} + +func TestDateCounterGetLastDaysCountReturnsCopy(t *testing.T) { + require := require.New(t) + + clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local)) + dc := newStandardDateCounterWithClock(3, clk) + dc.counts = []int64{10, 7, 3} + + counts := dc.GetLastDaysCount(2) + require.Equal([]int64{10, 7}, counts) + + counts[0] = 100 + require.Equal([]int64{10, 7}, dc.GetLastDaysCount(2)) +} + +func TestDateCounterClear(t *testing.T) { + require := require.New(t) + + dc := newStandardDateCounter(3) + dc.counts = []int64{10, 7, 3} + + dc.Clear() + + require.Equal([]int64{0, 0, 0}, dc.counts) +} + +func TestDateCounterConcurrentAccess(t *testing.T) { + clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local)) + dc := newStandardDateCounterWithClock(3, clk) + + var wg sync.WaitGroup + for range 8 { + wg.Go(func() { + for range 100 { + dc.Inc(1) + dc.Dec(1) + _ = dc.TodayCount() + _ = dc.GetLastDaysCount(3) + _ = dc.Snapshot() + } + }) + } + wg.Wait() +} diff --git a/pkg/util/tcpmux/httpconnect.go b/pkg/util/tcpmux/httpconnect.go index 6be29a4a..d43b5904 100644 --- a/pkg/util/tcpmux/httpconnect.go +++ b/pkg/util/tcpmux/httpconnect.go @@ -39,11 +39,14 @@ type HTTPConnectTCPMuxer struct { func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) { ret := &HTTPConnectTCPMuxer{passthrough: passthrough} mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout) + if err != nil { + return nil, err + } mux.SetCheckAuthFunc(ret.auth). SetSuccessHookFunc(ret.sendConnectResponse). SetFailHookFunc(vhostFailed) ret.Muxer = mux - return ret, err + return ret, nil } func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) { diff --git a/pkg/util/vhost/http.go b/pkg/util/vhost/http.go index cac88e35..865ff9d5 100644 --- a/pkg/util/vhost/http.go +++ b/pkg/util/vhost/http.go @@ -24,7 +24,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "strings" "time" libio "github.com/fatedier/golib/io" @@ -160,7 +159,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) { } func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig { - vr, ok := rp.getVhost(domain, location, routeByHTTPUser) + vr, ok := rp.vhostRouter.getByRoute(domain, location, routeByHTTPUser) if ok { log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser) return vr.payload.(*RouteConfig) @@ -171,7 +170,7 @@ func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser str // CreateConnection create a new connection by route config func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) { host, _ := httppkg.CanonicalHost(reqRouteInfo.Host) - vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser) + vr, ok := rp.vhostRouter.getByRoute(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser) if ok { if byEndpoint { fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn @@ -208,50 +207,6 @@ func checkRouteAuthByRequest(req *http.Request, rc *RouteConfig) bool { return ok && user == rc.Username && passwd == rc.Password } -// getVhost tries to get vhost router by route policy. -func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) { - findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) { - vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser) - if ok { - return vr, ok - } - // Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all. - vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "") - if ok { - return vr, ok - } - return nil, false - } - - // First we check the full hostname - // if not exist, then check the wildcard_domain such as *.example.com - vr, ok := findRouter(domain, location, routeByHTTPUser) - if ok { - return vr, ok - } - - // e.g. domain = test.example.com, try to match wildcard domains. - // *.example.com - // *.com - domainSplit := strings.Split(domain, ".") - for len(domainSplit) >= 3 { - domainSplit[0] = "*" - domain = strings.Join(domainSplit, ".") - vr, ok = findRouter(domain, location, routeByHTTPUser) - if ok { - return vr, true - } - domainSplit = domainSplit[1:] - } - - // Finally, try to check if there is one proxy that domain is "*" means match all domains. - vr, ok = findRouter("*", location, routeByHTTPUser) - if ok { - return vr, true - } - return nil, false -} - func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) { hj, ok := rw.(http.Hijacker) if !ok { diff --git a/pkg/util/vhost/https.go b/pkg/util/vhost/https.go index bcfdb81e..857a10d2 100644 --- a/pkg/util/vhost/https.go +++ b/pkg/util/vhost/https.go @@ -29,11 +29,11 @@ type HTTPSMuxer struct { func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) { mux, err := NewMuxer(listener, GetHTTPSHostname, timeout) - mux.SetFailHookFunc(vhostFailed) if err != nil { return nil, err } - return &HTTPSMuxer{mux}, err + mux.SetFailHookFunc(vhostFailed) + return &HTTPSMuxer{mux}, nil } func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) { diff --git a/pkg/util/vhost/https_test.go b/pkg/util/vhost/https_test.go index 08a3f558..6c303687 100644 --- a/pkg/util/vhost/https_test.go +++ b/pkg/util/vhost/https_test.go @@ -16,23 +16,53 @@ func TestGetHTTPSHostname(t *testing.T) { require.NoError(err) defer l.Close() - var conn net.Conn + connCh := make(chan net.Conn, 1) + acceptErrCh := make(chan error, 1) go func() { - conn, _ = l.Accept() - require.NotNil(conn) + conn, err := l.Accept() + if err != nil { + acceptErrCh <- err + return + } + connCh <- conn }() + clientErrCh := make(chan error, 1) go func() { time.Sleep(100 * time.Millisecond) - tls.Dial("tcp", l.Addr().String(), &tls.Config{ + conn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{ InsecureSkipVerify: true, ServerName: "example.com", }) + if conn != nil { + _ = conn.Close() + } + clientErrCh <- err }() - time.Sleep(200 * time.Millisecond) - _, infos, err := GetHTTPSHostname(conn) + var conn net.Conn + select { + case conn = <-connCh: + case err := <-acceptErrCh: + require.NoError(err) + case <-time.After(time.Second): + t.Fatal("timed out waiting for accepted connection") + } + require.NotNil(conn) + + serverConn, infos, err := GetHTTPSHostname(conn) + if serverConn != nil { + _ = serverConn.Close() + } else { + _ = conn.Close() + } require.NoError(err) require.Equal("example.com", infos["Host"]) require.Equal("https", infos["Scheme"]) + + select { + case <-clientErrCh: + case <-time.After(time.Second): + t.Fatal("timed out waiting for TLS client") + } } diff --git a/pkg/util/vhost/router.go b/pkg/util/vhost/router.go index 315ff9e7..f8ec99d3 100644 --- a/pkg/util/vhost/router.go +++ b/pkg/util/vhost/router.go @@ -93,12 +93,19 @@ func (r *Routers) Del(domain, location, httpUser string) { routersByHTTPUser[httpUser] = newVrs } +// Get returns the best location match for an exact host and exact HTTP user. +// It does not apply all-users, wildcard-domain, or catch-all-domain fallback. func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) { host = strings.ToLower(host) r.mutex.RLock() defer r.mutex.RUnlock() + return r.getLocked(host, path, httpUser) +} + +// getLocked performs an exact-host lookup; host must already be lower-cased. +func (r *Routers) getLocked(host, path, httpUser string) (vr *Router, exist bool) { routersByHTTPUser, found := r.indexByDomain[host] if !found { return @@ -117,6 +124,49 @@ func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) { return } +func (r *Routers) getByRoute(host, path, httpUser string) (*Router, bool) { + host = strings.ToLower(host) + + r.mutex.RLock() + defer r.mutex.RUnlock() + + // First we check the full hostname; if it doesn't exist, then check wildcard domains. + // For example, test.example.com checks *.example.com before falling back to "*". + vr, ok := r.getExactOrAllUsersLocked(host, path, httpUser) + if ok { + return vr, true + } + + hostSplit := strings.Split(host, ".") + // Keep two-label hosts out of the wildcard walk, so example.com does not match *.com. + for len(hostSplit) >= 3 { + // Replace the leftmost remaining label with the wildcard marker. + hostSplit[0] = "*" + host = strings.Join(hostSplit, ".") + vr, ok = r.getExactOrAllUsersLocked(host, path, httpUser) + if ok { + return vr, true + } + hostSplit = hostSplit[1:] + } + + // Finally, try to check if there is one proxy whose domain is "*", which means match all domains. + return r.getExactOrAllUsersLocked("*", path, httpUser) +} + +func (r *Routers) getExactOrAllUsersLocked(host, path, httpUser string) (*Router, bool) { + vr, ok := r.getLocked(host, path, httpUser) + if ok { + return vr, true + } + // Try to check if there is one proxy that doesn't specify routeByHTTPUser, it means match all. + vr, ok = r.getLocked(host, path, "") + if ok { + return vr, true + } + return nil, false +} + func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) { routersByHTTPUser, found := r.indexByDomain[host] if !found { diff --git a/pkg/util/vhost/router_test.go b/pkg/util/vhost/router_test.go new file mode 100644 index 00000000..f7931edc --- /dev/null +++ b/pkg/util/vhost/router_test.go @@ -0,0 +1,257 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRoutersGet(t *testing.T) { + routers := NewRouters() + require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user")) + require.NoError(t, routers.Add("example.com", "/public", "", "exact-all-users")) + require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users")) + require.NoError(t, routers.Add("*", "/api", "", "all-domains")) + + t.Run("exact host and user match with normalized domain", func(t *testing.T) { + router, ok := routers.Get("EXAMPLE.COM", "/api/users", "alice") + require.True(t, ok) + require.Equal(t, "exact-user", router.payload) + }) + + t.Run("exact host and empty user match", func(t *testing.T) { + router, ok := routers.Get("EXAMPLE.COM", "/public/docs", "") + require.True(t, ok) + require.Equal(t, "exact-all-users", router.payload) + }) + + t.Run("does not fall back from named user to empty user", func(t *testing.T) { + // Get intentionally requires an exact HTTP user; route-level fallbacks live in getByRoute. + _, ok := routers.Get("EXAMPLE.COM", "/public/docs", "alice") + require.False(t, ok) + }) + + t.Run("does not fall back to wildcard domains", func(t *testing.T) { + _, ok := routers.Get("foo.example.com", "/api/users", "") + require.False(t, ok) + }) + + t.Run("does not fall back to catch-all domain", func(t *testing.T) { + _, ok := routers.Get("missing.test", "/api/users", "") + require.False(t, ok) + }) +} + +func TestRoutersGetByRoute(t *testing.T) { + routers := NewRouters() + require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user")) + require.NoError(t, routers.Add("example.com", "/api", "", "exact-all-users")) + require.NoError(t, routers.Add("exact.example.com", "/api", "", "exact-subdomain")) + require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users")) + require.NoError(t, routers.Add("*.foo.example.com", "/api", "", "specific-wildcard")) + require.NoError(t, routers.Add("*.bar.com", "/api", "", "wildcard-parent-domain")) + require.NoError(t, routers.Add("*", "/admin", "root", "all-domains-user")) + require.NoError(t, routers.Add("*", "/", "", "all-domains")) + + tests := []struct { + name string + domain string + location string + httpUser string + want string + }{ + { + name: "exact domain and http user", + domain: "example.com", + location: "/api/users", + httpUser: "alice", + want: "exact-user", + }, + { + name: "exact domain falls back to all users", + domain: "example.com", + location: "/api/users", + httpUser: "bob", + want: "exact-all-users", + }, + { + name: "wildcard domain uses all users fallback", + domain: "foo.example.com", + location: "/api/users", + httpUser: "bob", + want: "wildcard-all-users", + }, + { + name: "mixed-case domain is normalized", + domain: "Foo.Example.Com", + location: "/api/users", + httpUser: "bob", + want: "wildcard-all-users", + }, + { + name: "exact domain wins over wildcard domain", + domain: "exact.example.com", + location: "/api/users", + httpUser: "bob", + want: "exact-subdomain", + }, + { + name: "more specific wildcard wins over broader wildcard", + domain: "bar.foo.example.com", + location: "/api/users", + httpUser: "bob", + want: "specific-wildcard", + }, + { + name: "wildcard walk checks parent domains", + domain: "a.b.bar.com", + location: "/api/users", + httpUser: "bob", + want: "wildcard-parent-domain", + }, + { + name: "catch-all domain fallback", + domain: "foo.test.com", + location: "/other", + httpUser: "bob", + want: "all-domains", + }, + { + name: "catch-all domain honors http user", + domain: "foo.test.com", + location: "/admin/panel", + httpUser: "root", + want: "all-domains-user", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, ok := routers.getByRoute(tt.domain, tt.location, tt.httpUser) + require.True(t, ok) + require.Equal(t, tt.want, router.payload) + }) + } +} + +func TestRoutersGetByRouteNoMatch(t *testing.T) { + routers := NewRouters() + require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users")) + require.NoError(t, routers.Add("*.com", "/api", "", "top-level-wildcard")) + + tests := []struct { + name string + domain string + location string + }{ + { + name: "two-label domain does not enter wildcard walk", + domain: "example.com", + location: "/api/users", + }, + { + name: "missing catch-all remains no match", + domain: "foo.test.com", + location: "/api/users", + }, + { + name: "wrong path remains no match", + domain: "foo.example.com", + location: "/other", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, ok := routers.getByRoute(tt.domain, tt.location, "bob") + require.False(t, ok) + require.Nil(t, router) + }) + } +} + +func TestRoutersConcurrentGetByRouteAndAdd(t *testing.T) { + routers := NewRouters() + require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard")) + + const readers = 8 + const iterations = 200 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + start := make(chan struct{}) + errCh := make(chan error, readers+1) + var wg sync.WaitGroup + + for id := range readers { + wg.Go(func() { + <-start + for j := range iterations { + select { + case <-ctx.Done(): + return + default: + } + router, ok := routers.getByRoute("foo.example.com", "/api/users", "") + if !ok || router == nil || router.payload != "wildcard" { + errCh <- fmt.Errorf("reader %d iteration %d got router=%v ok=%v", id, j, router, ok) + return + } + } + }) + } + + wg.Go(func() { + <-start + for i := range iterations { + select { + case <-ctx.Done(): + return + default: + } + err := routers.Add(fmt.Sprintf("host-%d.example.com", i), "/api", "", i) + if err != nil { + errCh <- err + return + } + } + }) + + close(start) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + cancel() + t.Fatal("concurrent route lookup and add timed out") + } + + close(errCh) + for err := range errCh { + require.NoError(t, err) + } +} diff --git a/pkg/util/vhost/vhost.go b/pkg/util/vhost/vhost.go index 007751d7..637c8dcf 100644 --- a/pkg/util/vhost/vhost.go +++ b/pkg/util/vhost/vhost.go @@ -148,41 +148,9 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err } func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) { - findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) { - vr, ok := v.registryRouter.Get(inName, inPath, inHTTPUser) - if ok { - return vr.payload.(*Listener), true - } - // Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all. - vr, ok = v.registryRouter.Get(inName, inPath, "") - if ok { - return vr.payload.(*Listener), true - } - return nil, false - } - - // first we check the full hostname - // if not exist, then check the wildcard_domain such as *.example.com - l, ok := findRouter(name, path, httpUser) + vr, ok := v.registryRouter.getByRoute(name, path, httpUser) if ok { - return l, true - } - - domainSplit := strings.Split(name, ".") - for len(domainSplit) >= 3 { - domainSplit[0] = "*" - name = strings.Join(domainSplit, ".") - - l, ok = findRouter(name, path, httpUser) - if ok { - return l, true - } - domainSplit = domainSplit[1:] - } - // Finally, try to check if there is one proxy that domain is "*" means match all domains. - l, ok = findRouter("*", path, httpUser) - if ok { - return l, true + return vr.payload.(*Listener), true } return nil, false } diff --git a/pkg/util/wait/backoff.go b/pkg/util/wait/backoff.go index 4d01ace3..0ee5db71 100644 --- a/pkg/util/wait/backoff.go +++ b/pkg/util/wait/backoff.go @@ -18,6 +18,8 @@ import ( "math/rand/v2" "time" + "k8s.io/utils/clock" + "github.com/fatedier/frp/pkg/util/util" ) @@ -48,6 +50,7 @@ type FastBackoffOptions struct { type fastBackoffImpl struct { options FastBackoffOptions + clock clock.PassiveClock lastCalledTime time.Time consecutiveErrCount int @@ -57,18 +60,26 @@ type fastBackoffImpl struct { } func NewFastBackoffManager(options FastBackoffOptions) BackoffManager { + return newFastBackoffManagerWithClock(options, clock.RealClock{}) +} + +func newFastBackoffManagerWithClock(options FastBackoffOptions, clk clock.PassiveClock) BackoffManager { + if clk == nil { + clk = clock.RealClock{} + } return &fastBackoffImpl{ options: options, + clock: clk, countsInFastRetryWindow: 1, } } func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration { if f.lastCalledTime.IsZero() { - f.lastCalledTime = time.Now() + f.lastCalledTime = f.clock.Now() return f.options.Duration } - now := time.Now() + now := f.clock.Now() f.lastCalledTime = now if previousConditionError { diff --git a/pkg/util/wait/backoff_test.go b/pkg/util/wait/backoff_test.go new file mode 100644 index 00000000..edea3669 --- /dev/null +++ b/pkg/util/wait/backoff_test.go @@ -0,0 +1,27 @@ +package wait + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +func TestFastBackoffManagerUsesClock(t *testing.T) { + require := require.New(t) + + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + backoff := newFastBackoffManagerWithClock(FastBackoffOptions{ + Duration: time.Second, + }, clk).(*fastBackoffImpl) + + require.Equal(time.Second, backoff.Backoff(0, false)) + require.Equal(start, backoff.lastCalledTime) + + next := start.Add(time.Minute) + clk.SetTime(next) + require.Equal(time.Second, backoff.Backoff(time.Second, false)) + require.Equal(next, backoff.lastCalledTime) +} diff --git a/pkg/util/xlog/log_writer.go b/pkg/util/xlog/log_writer.go index 3fff7324..28ddc22b 100644 --- a/pkg/util/xlog/log_writer.go +++ b/pkg/util/xlog/log_writer.go @@ -19,7 +19,6 @@ import "strings" // LogWriter forwards writes to frp's logger at configurable level. // It is safe for concurrent use as long as the underlying Logger is thread-safe. type LogWriter struct { - xl *Logger logFunc func(string) } @@ -31,35 +30,30 @@ func (w LogWriter) Write(p []byte) (n int, err error) { func NewTraceWriter(xl *Logger) LogWriter { return LogWriter{ - xl: xl, logFunc: func(msg string) { xl.Tracef("%s", msg) }, } } func NewDebugWriter(xl *Logger) LogWriter { return LogWriter{ - xl: xl, logFunc: func(msg string) { xl.Debugf("%s", msg) }, } } func NewInfoWriter(xl *Logger) LogWriter { return LogWriter{ - xl: xl, logFunc: func(msg string) { xl.Infof("%s", msg) }, } } func NewWarnWriter(xl *Logger) LogWriter { return LogWriter{ - xl: xl, logFunc: func(msg string) { xl.Warnf("%s", msg) }, } } func NewErrorWriter(xl *Logger) LogWriter { return LogWriter{ - xl: xl, logFunc: func(msg string) { xl.Errorf("%s", msg) }, } } diff --git a/pkg/vnet/controller.go b/pkg/vnet/controller.go index d5c97c66..348d50d2 100644 --- a/pkg/vnet/controller.go +++ b/pkg/vnet/controller.go @@ -286,7 +286,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri r.mu.Lock() defer r.mu.Unlock() r.routes[name] = &routeElement{ - name: name, routes: routes, conn: conn, } @@ -383,7 +382,6 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) { } type routeElement struct { - name string routes []net.IPNet conn io.ReadWriteCloser } diff --git a/server/ports/ports.go b/server/ports/ports.go index 5a73fbc5..653715ea 100644 --- a/server/ports/ports.go +++ b/server/ports/ports.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "k8s.io/utils/clock" + "github.com/fatedier/frp/pkg/config/types" ) @@ -38,16 +40,25 @@ type Manager struct { bindAddr string netType string + clock clock.WithTicker mu sync.Mutex } func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager { + return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{}) +} + +func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager { + if clk == nil { + clk = clock.RealClock{} + } pm := &Manager{ reservedPorts: make(map[string]*PortCtx), usedPorts: make(map[int]*PortCtx), freePorts: make(map[int]struct{}), bindAddr: bindAddr, netType: netType, + clock: clk, } if len(allowPorts) > 0 { for _, pair := range allowPorts { @@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { portCtx := &PortCtx{ ProxyName: name, Closed: false, - UpdateTime: time.Now(), + UpdateTime: pm.clock.Now(), } var ok bool @@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { if ctx, ok := pm.reservedPorts[name]; ok { if pm.isPortAvailable(ctx.Port) { realPort = ctx.Port - pm.usedPorts[realPort] = portCtx - pm.reservedPorts[name] = portCtx - delete(pm.freePorts, realPort) + pm.markPortAcquiredLocked(name, realPort, portCtx) return } } @@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { } if pm.isPortAvailable(k) { realPort = k - pm.usedPorts[realPort] = portCtx - pm.reservedPorts[name] = portCtx - delete(pm.freePorts, realPort) + pm.markPortAcquiredLocked(name, realPort, portCtx) break } } @@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { if _, ok = pm.freePorts[port]; ok { if pm.isPortAvailable(port) { realPort = port - pm.usedPorts[realPort] = portCtx - pm.reservedPorts[name] = portCtx - delete(pm.freePorts, realPort) + pm.markPortAcquiredLocked(name, realPort, portCtx) } else { err = ErrPortUnAvailable } @@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) { return } +// markPortAcquiredLocked records a successful acquisition. pm.mu must be held. +func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) { + pm.usedPorts[port] = portCtx + pm.reservedPorts[name] = portCtx + delete(pm.freePorts, port) +} + func (pm *Manager) isPortAvailable(port int) bool { if pm.netType == "udp" { addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port))) @@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) { pm.freePorts[port] = struct{}{} delete(pm.usedPorts, port) ctx.Closed = true - ctx.UpdateTime = time.Now() + ctx.UpdateTime = pm.clock.Now() } } // Release reserved port if it isn't used in last 24 hours. func (pm *Manager) cleanReservedPortsWorker() { + pm.cleanReservedPortsWorkerUntil(nil) +} + +func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) { + ticker := pm.clock.NewTicker(CleanReservedPortsInterval) + defer ticker.Stop() + for { - time.Sleep(CleanReservedPortsInterval) - pm.mu.Lock() - for name, ctx := range pm.reservedPorts { - if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration { - delete(pm.reservedPorts, name) - } + select { + case <-ticker.C(): + pm.cleanReservedPortsOnce() + case <-stopCh: + return + } + } +} + +func (pm *Manager) cleanReservedPortsOnce() { + pm.mu.Lock() + defer pm.mu.Unlock() + + for name, ctx := range pm.reservedPorts { + if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration { + delete(pm.reservedPorts, name) } - pm.mu.Unlock() } } diff --git a/server/ports/ports_test.go b/server/ports/ports_test.go new file mode 100644 index 00000000..cd560b25 --- /dev/null +++ b/server/ports/ports_test.go @@ -0,0 +1,72 @@ +package ports + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" + + "github.com/fatedier/frp/pkg/config/types" +) + +func TestManagerUsesClockForPortTimestamps(t *testing.T) { + require := require.New(t) + + port := freeTCPPort(t) + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk) + + realPort, err := pm.Acquire("proxy", port) + require.NoError(err) + require.Equal(port, realPort) + require.Equal(start, pm.usedPorts[port].UpdateTime) + + releasedAt := start.Add(time.Minute) + clk.SetTime(releasedAt) + pm.Release(port) + + require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime) +} + +func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) { + require := require.New(t) + + port := freeTCPPort(t) + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk) + + realPort, err := pm.Acquire("proxy", port) + require.NoError(err) + require.Equal(port, realPort) + pm.Release(port) + require.True(pm.hasReservedPort("proxy")) + + require.Eventually(clk.HasWaiters, time.Second, time.Millisecond) + clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute) + + require.Eventually(func() bool { + return !pm.hasReservedPort("proxy") + }, time.Second, time.Millisecond) +} + +func (pm *Manager) hasReservedPort(name string) bool { + pm.mu.Lock() + defer pm.mu.Unlock() + + _, ok := pm.reservedPorts[name] + return ok +} + +func freeTCPPort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + return listener.Addr().(*net.TCPAddr).Port +} diff --git a/server/registry/registry.go b/server/registry/registry.go index c521d1c2..21771bce 100644 --- a/server/registry/registry.go +++ b/server/registry/registry.go @@ -18,6 +18,8 @@ import ( "fmt" "sync" "time" + + "k8s.io/utils/clock" ) // ClientInfo captures metadata about a connected frpc instance. @@ -42,12 +44,21 @@ type ClientRegistry struct { mu sync.RWMutex clients map[string]*ClientInfo runIndex map[string]string + clock clock.PassiveClock } func NewClientRegistry() *ClientRegistry { + return newClientRegistryWithClock(clock.RealClock{}) +} + +func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry { + if clk == nil { + clk = clock.RealClock{} + } return &ClientRegistry{ clients: make(map[string]*ClientInfo), runIndex: make(map[string]string), + clock: clk, } } @@ -64,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, key = cr.composeClientKey(user, effectiveID) enforceUnique := rawClientID != "" - now := time.Now() + now := cr.clock.Now() cr.mu.Lock() defer cr.mu.Unlock() @@ -116,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) { } else { info.RunID = "" info.Online = false - now := time.Now() + now := cr.clock.Now() info.DisconnectedAt = now } } diff --git a/server/registry/registry_test.go b/server/registry/registry_test.go index cf428964..0ff083b8 100644 --- a/server/registry/registry_test.go +++ b/server/registry/registry_test.go @@ -16,6 +16,9 @@ package registry import ( "testing" + "time" + + clocktesting "k8s.io/utils/clock/testing" "github.com/fatedier/frp/pkg/proto/wire" ) @@ -35,3 +38,37 @@ func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) { t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol) } } + +func TestClientRegistryUsesClockForTimestamps(t *testing.T) { + start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC) + clk := clocktesting.NewFakeClock(start) + registry := newClientRegistryWithClock(clk) + + key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2) + if conflict { + t.Fatal("unexpected client conflict") + } + + info, ok := registry.GetByKey(key) + if !ok { + t.Fatalf("client %q not found", key) + } + if !info.FirstConnectedAt.Equal(start) { + t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt) + } + if !info.LastConnectedAt.Equal(start) { + t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt) + } + + disconnectedAt := start.Add(time.Minute) + clk.SetTime(disconnectedAt) + registry.MarkOfflineByRunID("run-id") + + info, ok = registry.GetByKey(key) + if !ok { + t.Fatalf("client %q not found after disconnect", key) + } + if !info.DisconnectedAt.Equal(disconnectedAt) { + t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt) + } +} diff --git a/test/e2e/framework/cleanup.go b/test/e2e/framework/cleanup.go deleted file mode 100644 index 15e1a9dd..00000000 --- a/test/e2e/framework/cleanup.go +++ /dev/null @@ -1,62 +0,0 @@ -package framework - -import ( - "sync" -) - -// CleanupActionHandle is an integer pointer type for handling cleanup action -type CleanupActionHandle *int - -type cleanupFuncHandle struct { - actionHandle CleanupActionHandle - actionHook func() -} - -var ( - cleanupActionsLock sync.Mutex - cleanupHookList = []cleanupFuncHandle{} -) - -// AddCleanupAction installs a function that will be called in the event of the -// whole test being terminated. This allows arbitrary pieces of the overall -// test to hook into SynchronizedAfterSuite(). -// The hooks are called in last-in-first-out order. -func AddCleanupAction(fn func()) CleanupActionHandle { - p := CleanupActionHandle(new(int)) - cleanupActionsLock.Lock() - defer cleanupActionsLock.Unlock() - c := cleanupFuncHandle{actionHandle: p, actionHook: fn} - cleanupHookList = append([]cleanupFuncHandle{c}, cleanupHookList...) - return p -} - -// RemoveCleanupAction removes a function that was installed by -// AddCleanupAction. -func RemoveCleanupAction(p CleanupActionHandle) { - cleanupActionsLock.Lock() - defer cleanupActionsLock.Unlock() - for i, item := range cleanupHookList { - if item.actionHandle == p { - cleanupHookList = append(cleanupHookList[:i], cleanupHookList[i+1:]...) - break - } - } -} - -// RunCleanupActions runs all functions installed by AddCleanupAction. It does -// not remove them (see RemoveCleanupAction) but it does run unlocked, so they -// may remove themselves. -func RunCleanupActions() { - list := []func(){} - func() { - cleanupActionsLock.Lock() - defer cleanupActionsLock.Unlock() - for _, p := range cleanupHookList { - list = append(list, p.actionHook) - } - }() - // Run unlocked. - for _, fn := range list { - fn() - } -} diff --git a/test/e2e/framework/expect.go b/test/e2e/framework/expect.go index 8e904489..ea80ecb8 100644 --- a/test/e2e/framework/expect.go +++ b/test/e2e/framework/expect.go @@ -23,11 +23,6 @@ func ExpectNotEqual(actual any, extra any, explain ...any) { gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...) } -// ExpectError expects an error happens, otherwise an exception raises -func ExpectError(err error, explain ...any) { - gomega.ExpectWithOffset(1, err).To(gomega.HaveOccurred(), explain...) -} - func ExpectErrorWithOffset(offset int, err error, explain ...any) { gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...) } @@ -47,11 +42,6 @@ func ExpectContainSubstring(actual, substr string, explain ...any) { gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...) } -// ExpectConsistOf expects actual contains precisely the extra elements. The ordering of the elements does not matter. -func ExpectConsistOf(actual any, extra any, explain ...any) { - gomega.ExpectWithOffset(1, actual).To(gomega.ConsistOf(extra), explain...) -} - func ExpectContainElements(actual any, extra any, explain ...any) { gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...) } @@ -60,16 +50,6 @@ func ExpectNotContainElements(actual any, extra any, explain ...any) { gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...) } -// ExpectHaveKey expects the actual map has the key in the keyset -func ExpectHaveKey(actual any, key any, explain ...any) { - gomega.ExpectWithOffset(1, actual).To(gomega.HaveKey(key), explain...) -} - -// ExpectEmpty expects actual is empty -func ExpectEmpty(actual any, explain ...any) { - gomega.ExpectWithOffset(1, actual).To(gomega.BeEmpty(), explain...) -} - func ExpectTrue(actual any, explain ...any) { gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...) } diff --git a/test/e2e/framework/framework.go b/test/e2e/framework/framework.go index 7713fb3e..34c0eb63 100644 --- a/test/e2e/framework/framework.go +++ b/test/e2e/framework/framework.go @@ -38,11 +38,6 @@ type Framework struct { // Multiple default mock servers used for e2e testing. mockServers *MockServers - // To make sure that this framework cleans up after itself, no matter what, - // we install a Cleanup action before each test and clear it after. If we - // should abort, the AfterSuite hook should run all Cleanup actions. - cleanupHandle CleanupActionHandle - // beforeEachStarted indicates that BeforeEach has started beforeEachStarted bool @@ -87,8 +82,6 @@ func NewFramework(opt Options) *Framework { func (f *Framework) BeforeEach() { f.beforeEachStarted = true - f.cleanupHandle = AddCleanupAction(f.AfterEach) - dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*") ExpectNoError(err) f.TempDirectory = dir @@ -113,8 +106,6 @@ func (f *Framework) AfterEach() { return } - RemoveCleanupAction(f.cleanupHandle) - // stop processor for _, p := range f.serverProcesses { _ = p.Stop() @@ -266,10 +257,6 @@ func (f *Framework) AllocPortExcludingRanges(ranges ...[2]int) int { return 0 } -func (f *Framework) ReleasePort(port int) { - f.portAllocator.Release(port) -} - func (f *Framework) RunServer(portName string, s server.Server) { f.servers = append(f.servers, s) if s.BindPort() > 0 && portName != "" { diff --git a/test/e2e/framework/mockservers.go b/test/e2e/framework/mockservers.go index 6b6c2868..65d9f621 100644 --- a/test/e2e/framework/mockservers.go +++ b/test/e2e/framework/mockservers.go @@ -75,11 +75,3 @@ func (m *MockServers) GetTemplateParams() map[string]any { ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort() return ret } - -func (m *MockServers) GetParam(key string) any { - params := m.GetTemplateParams() - if v, ok := params[key]; ok { - return v - } - return nil -} diff --git a/test/e2e/framework/request.go b/test/e2e/framework/request.go index 599ff11b..a6ac565a 100644 --- a/test/e2e/framework/request.go +++ b/test/e2e/framework/request.go @@ -124,7 +124,3 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) { } } } - -func (e *RequestExpect) Do() (*request.Response, error) { - return e.req.Do() -} diff --git a/test/e2e/mock/server/httpserver/server.go b/test/e2e/mock/server/httpserver/server.go index f9818a0d..fb80288f 100644 --- a/test/e2e/mock/server/httpserver/server.go +++ b/test/e2e/mock/server/httpserver/server.go @@ -31,13 +31,6 @@ func New(options ...Option) *Server { return s } -func WithBindAddr(addr string) Option { - return func(s *Server) *Server { - s.bindAddr = addr - return s - } -} - func WithBindPort(port int) Option { return func(s *Server) *Server { s.bindPort = port diff --git a/test/e2e/mock/server/oidcserver/oidcserver.go b/test/e2e/mock/server/oidcserver/oidcserver.go index 22236dc0..eb7ce2e4 100644 --- a/test/e2e/mock/server/oidcserver/oidcserver.go +++ b/test/e2e/mock/server/oidcserver/oidcserver.go @@ -61,21 +61,6 @@ func WithBindPort(port int) Option { return func(s *Server) { s.bindPort = port } } -func WithClientCredentials(id, secret string) Option { - return func(s *Server) { - s.clientID = id - s.clientSecret = secret - } -} - -func WithAudience(aud string) Option { - return func(s *Server) { s.audience = aud } -} - -func WithSubject(sub string) Option { - return func(s *Server) { s.subject = sub } -} - func WithExpiresIn(seconds int) Option { return func(s *Server) { s.expiresIn = seconds } } diff --git a/test/e2e/pkg/process/process.go b/test/e2e/pkg/process/process.go index 8235461b..0e05b90d 100644 --- a/test/e2e/pkg/process/process.go +++ b/test/e2e/pkg/process/process.go @@ -40,13 +40,8 @@ type Process struct { closeOne sync.Once waitErr error - started bool - beforeStopHandler func() - stopped bool -} - -func New(path string, params []string) *Process { - return NewWithEnvs(path, params, nil) + started bool + stopped bool } func NewWithEnvs(path string, params []string, envs []string) *Process { @@ -100,9 +95,6 @@ func (p *Process) Stop() error { defer func() { p.stopped = true }() - if p.beforeStopHandler != nil { - p.beforeStopHandler() - } p.cancel() <-p.done return p.waitErr @@ -125,10 +117,6 @@ func (p *Process) CountOutput(pattern string) int { return strings.Count(p.Output(), pattern) } -func (p *Process) SetBeforeStopHandler(fn func()) { - p.beforeStopHandler = fn -} - // WaitForOutput polls the combined process output until the pattern is found // count time(s) or the timeout is reached. It also returns early if the process exits. func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error { diff --git a/test/e2e/pkg/request/request.go b/test/e2e/pkg/request/request.go index 211cc425..c3b72669 100644 --- a/test/e2e/pkg/request/request.go +++ b/test/e2e/pkg/request/request.go @@ -22,11 +22,10 @@ type Request struct { protocol string // for all protocol - addr string - port int - body []byte - timeout time.Duration - resolver *net.Resolver + addr string + port int + body []byte + timeout time.Duration // for http or https method string @@ -134,11 +133,6 @@ func (r *Request) Body(content []byte) *Request { return r } -func (r *Request) Resolver(resolver *net.Resolver) *Request { - r.resolver = resolver - return r -} - func (r *Request) Do() (*Response, error) { var ( conn net.Conn @@ -169,7 +163,7 @@ func (r *Request) Do() (*Response, error) { return nil, err } } else { - dialer := &net.Dialer{Resolver: r.resolver} + dialer := &net.Dialer{} switch r.protocol { case "tcp": conn, err = dialer.Dial("tcp", addr) @@ -225,7 +219,6 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma Timeout: time.Second, KeepAlive: 30 * time.Second, DualStack: true, - Resolver: r.resolver, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second,