mirror of
https://github.com/fatedier/frp.git
synced 2026-05-15 08:05:49 -06:00
refactor: clean up code (#5308)
Some checks failed
golangci-lint / lint (push) Has been cancelled
Some checks failed
golangci-lint / lint (push) Has been cancelled
This commit is contained in:
parent
ad07d27914
commit
a88e0e9a49
49 changed files with 2082 additions and 931 deletions
2
go.mod
2
go.mod
|
|
@ -36,6 +36,7 @@ require (
|
|||
gopkg.in/ini.v1 v1.67.0
|
||||
k8s.io/apimachinery v0.28.8
|
||||
k8s.io/client-go v0.28.8
|
||||
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
@ -75,7 +76,6 @@ require (
|
|||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
|
||||
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
|
||||
sigs.k8s.io/yaml v1.3.0 // indirect
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,7 @@
|
|||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
import v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
|
||||
// ConfigSource implements Source for in-memory configuration.
|
||||
// All operations are thread-safe.
|
||||
|
|
@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi
|
|||
|
||||
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
|
||||
for _, p := range proxies {
|
||||
if p == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := p.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
name, err := validateProxyName(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextProxies[name] = p
|
||||
}
|
||||
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
|
||||
for _, v := range visitors {
|
||||
if v == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := v.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
name, err := validateVisitorName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextVisitors[name] = v
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,6 +43,11 @@ var (
|
|||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
const (
|
||||
storeKindProxy = "proxy"
|
||||
storeKindVisitor = "visitor"
|
||||
)
|
||||
|
||||
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
||||
if cfg.Path == "" {
|
||||
return nil, fmt.Errorf("path is required")
|
||||
|
|
@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error {
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
rollback()
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store map selectors return the target map for generic helpers.
|
||||
func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer {
|
||||
return s.proxies
|
||||
}
|
||||
|
||||
func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer {
|
||||
return s.visitors
|
||||
}
|
||||
|
||||
// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps.
|
||||
// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer.
|
||||
func addStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
if _, exists := entries[name]; exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name)
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
delete(entries, name)
|
||||
})
|
||||
}
|
||||
|
||||
func updateStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
func removeStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
return fmt.Errorf("%s name cannot be empty", kind)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.proxies[name]; exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
delete(entries, name)
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.proxies, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveProxy(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.proxies, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||
|
|
@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
|||
}
|
||||
|
||||
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.visitors[name]; exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.visitors, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveVisitor(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.visitors, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package source
|
|||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T
|
|||
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
||||
}
|
||||
|
||||
func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
storeSource := newTestStoreSource(t)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists)
|
||||
require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists)
|
||||
require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty")
|
||||
require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty")
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
require.NoError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
require.NoError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
require.NoError(storeSource.RemoveProxy("proxy1"))
|
||||
require.Nil(storeSource.GetProxy("proxy1"))
|
||||
require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound)
|
||||
|
||||
require.NoError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.Nil(storeSource.GetVisitor("visitor1"))
|
||||
require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound)
|
||||
|
||||
require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound)
|
||||
require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("chmod does not make directories unwritable on Windows")
|
||||
}
|
||||
if os.Getuid() == 0 {
|
||||
t.Skip("chmod does not block writes for uid 0")
|
||||
}
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
|
||||
require.NoError(err)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort
|
||||
originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
|
||||
require.NoError(os.Chmod(dir, 0o500))
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chmod(dir, 0o700)
|
||||
})
|
||||
|
||||
requirePersistError := func(err error) {
|
||||
t.Helper()
|
||||
require.Error(err)
|
||||
require.ErrorContains(err, "failed to persist")
|
||||
require.NotErrorIs(err, ErrAlreadyExists)
|
||||
require.NotErrorIs(err, ErrNotFound)
|
||||
}
|
||||
|
||||
requirePersistError(storeSource.AddProxy(mockProxy("proxy2")))
|
||||
require.Nil(storeSource.GetProxy("proxy2"))
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
requirePersistError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
requirePersistError(storeSource.RemoveProxy("proxy1"))
|
||||
require.NotNil(storeSource.GetProxy("proxy1"))
|
||||
|
||||
requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2")))
|
||||
require.Nil(storeSource.GetVisitor("visitor2"))
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
requirePersistError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
requirePersistError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.NotNil(storeSource.GetVisitor("visitor1"))
|
||||
}
|
||||
|
||||
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
|
|
|||
43
pkg/config/source/validation.go
Normal file
43
pkg/config/source/validation.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func validateProxyName(proxy v1.ProxyConfigurer) (string, error) {
|
||||
if proxy == nil {
|
||||
return "", fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) {
|
||||
if visitor == nil {
|
||||
return "", fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
43
pkg/config/v1/validation/auth.go
Normal file
43
pkg/config/v1/validation/auth.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func (v *ConfigValidator) validateAuthTokenSource(token string, tokenSource *v1.ValueSource) error {
|
||||
var errs error
|
||||
// Preserve the previous client/server validation order for joined errors.
|
||||
if token != "" && tokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
if tokenSource == nil {
|
||||
return errs
|
||||
}
|
||||
|
||||
if tokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := tokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
return errs
|
||||
}
|
||||
228
pkg/config/v1/validation/auth_test.go
Normal file
228
pkg/config/v1/validation/auth_test.go
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenSourceConflictErr = "cannot specify both auth.token and auth.tokenSource"
|
||||
tokenSourceExecErr = "unsafe feature \"TokenSourceExec\" is not enabled. To enable it, ensure it is allowed in the configuration or command line flags"
|
||||
invalidFileSourceErr = "invalid auth.tokenSource: file configuration is required when type is 'file'"
|
||||
unsupportedSourceErr = "invalid auth.tokenSource: unsupported value source type: env (only 'file' and 'exec' are supported)"
|
||||
)
|
||||
|
||||
func TestValidateAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
err := validator.validateAuthTokenSource(tc.token, tc.tokenSource())
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClientAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
auth := v1.AuthClientConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: tc.token,
|
||||
TokenSource: tc.tokenSource(),
|
||||
}
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
_, err := validator.ValidateClientCommonConfig(validClientConfigWithAuth(auth))
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateServerAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
auth := v1.AuthServerConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: tc.token,
|
||||
TokenSource: tc.tokenSource(),
|
||||
}
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
_, err := validator.ValidateServerConfig(validServerConfigWithAuth(auth))
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type authTokenSourceTestCase struct {
|
||||
name string
|
||||
token string
|
||||
tokenSource func() *v1.ValueSource
|
||||
unsafeAllowed bool
|
||||
wantErrs []string
|
||||
}
|
||||
|
||||
func authTokenSourceTestCases() []authTokenSourceTestCase {
|
||||
return []authTokenSourceTestCase{
|
||||
{
|
||||
name: "empty token config",
|
||||
tokenSource: nilTokenSource,
|
||||
},
|
||||
{
|
||||
name: "valid file tokenSource",
|
||||
tokenSource: validFileTokenSource,
|
||||
},
|
||||
{
|
||||
name: "literal token without tokenSource",
|
||||
token: "token",
|
||||
tokenSource: nilTokenSource,
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with file tokenSource",
|
||||
token: "token",
|
||||
tokenSource: validFileTokenSource,
|
||||
wantErrs: []string{tokenSourceConflictErr},
|
||||
},
|
||||
{
|
||||
name: "exec tokenSource requires unsafe feature",
|
||||
tokenSource: validExecTokenSource,
|
||||
wantErrs: []string{tokenSourceExecErr},
|
||||
},
|
||||
{
|
||||
name: "exec tokenSource with unsafe feature allowed",
|
||||
tokenSource: validExecTokenSource,
|
||||
unsafeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with exec tokenSource and unsafe feature disabled",
|
||||
token: "token",
|
||||
tokenSource: validExecTokenSource,
|
||||
wantErrs: []string{
|
||||
tokenSourceConflictErr,
|
||||
tokenSourceExecErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with exec tokenSource and unsafe feature allowed",
|
||||
token: "token",
|
||||
tokenSource: validExecTokenSource,
|
||||
unsafeAllowed: true,
|
||||
wantErrs: []string{tokenSourceConflictErr},
|
||||
},
|
||||
{
|
||||
name: "invalid file tokenSource is wrapped",
|
||||
tokenSource: invalidFileTokenSource,
|
||||
wantErrs: []string{invalidFileSourceErr},
|
||||
},
|
||||
{
|
||||
name: "unsupported tokenSource type is wrapped",
|
||||
tokenSource: unsupportedTokenSource,
|
||||
wantErrs: []string{unsupportedSourceErr},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newAuthTokenSourceValidator(unsafeAllowed bool) *ConfigValidator {
|
||||
if !unsafeAllowed {
|
||||
return NewConfigValidator(nil)
|
||||
}
|
||||
return NewConfigValidator(security.NewUnsafeFeatures([]string{security.TokenSourceExec}))
|
||||
}
|
||||
|
||||
func requireValidationErrors(t *testing.T, err error, wantErrs []string) {
|
||||
t.Helper()
|
||||
if len(wantErrs) == 0 {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
require.Error(t, err)
|
||||
// Client/server validators may wrap joined errors in another join layer; compare leaf errors.
|
||||
gotErrs := unwrapValidationErrors(err)
|
||||
require.Len(t, gotErrs, len(wantErrs))
|
||||
for i, wantErr := range wantErrs {
|
||||
require.EqualError(t, gotErrs[i], wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func unwrapValidationErrors(err error) []error {
|
||||
type joinedError interface {
|
||||
Unwrap() []error
|
||||
}
|
||||
joined, ok := err.(joinedError)
|
||||
if !ok {
|
||||
return []error{err}
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, err := range joined.Unwrap() {
|
||||
errs = append(errs, unwrapValidationErrors(err)...)
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// nilTokenSource keeps the shared table shape uniform for cases without a tokenSource.
|
||||
func nilTokenSource() *v1.ValueSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validFileTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "file",
|
||||
File: &v1.FileSource{Path: "token.txt"},
|
||||
}
|
||||
}
|
||||
|
||||
func validExecTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "exec",
|
||||
Exec: &v1.ExecSource{Command: "print-token"},
|
||||
}
|
||||
}
|
||||
|
||||
func invalidFileTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "file",
|
||||
}
|
||||
}
|
||||
|
||||
func unsupportedTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{Type: "env"}
|
||||
}
|
||||
|
||||
func validClientConfigWithAuth(auth v1.AuthClientConfig) *v1.ClientCommonConfig {
|
||||
return &v1.ClientCommonConfig{
|
||||
Auth: auth,
|
||||
Log: v1.LogConfig{
|
||||
Level: "info",
|
||||
},
|
||||
Transport: v1.ClientTransportConfig{
|
||||
Protocol: "tcp",
|
||||
WireProtocol: "v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validServerConfigWithAuth(auth v1.AuthServerConfig) *v1.ServerConfig {
|
||||
return &v1.ServerConfig{
|
||||
Auth: auth,
|
||||
Log: v1.LogConfig{
|
||||
Level: "info",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -68,22 +68,7 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e
|
|||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Token != "" && c.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.TokenSource != nil {
|
||||
if c.TokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
errs = AppendError(errs, v.validateAuthTokenSource(c.Token, c.TokenSource))
|
||||
|
||||
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
|
|
@ -36,22 +35,7 @@ func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, err
|
|||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.Auth.TokenSource != nil {
|
||||
if c.Auth.TokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
errs = AppendError(errs, v.validateAuthTokenSource(c.Auth.Token, c.Auth.TokenSource))
|
||||
|
||||
if err := validateLogConfig(&c.Log); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/metric"
|
||||
server "github.com/fatedier/frp/server/metrics"
|
||||
|
|
@ -37,12 +39,21 @@ func init() {
|
|||
}
|
||||
|
||||
type serverMetrics struct {
|
||||
info *ServerStatistics
|
||||
mu sync.Mutex
|
||||
info *ServerStatistics
|
||||
clock clock.WithTicker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newServerMetrics() *serverMetrics {
|
||||
return newServerMetricsWithClock(clock.RealClock{})
|
||||
}
|
||||
|
||||
func newServerMetricsWithClock(clk clock.WithTicker) *serverMetrics {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &serverMetrics{
|
||||
clock: clk,
|
||||
info: &ServerStatistics{
|
||||
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
|
||||
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
|
||||
|
|
@ -57,14 +68,23 @@ func newServerMetrics() *serverMetrics {
|
|||
}
|
||||
|
||||
func (m *serverMetrics) run() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(12 * time.Hour)
|
||||
start := time.Now()
|
||||
go m.runUntil(nil)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) runUntil(stopCh <-chan struct{}) {
|
||||
ticker := m.clock.NewTicker(12 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C():
|
||||
start := m.clock.Now()
|
||||
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
|
||||
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start))
|
||||
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, m.clock.Since(start))
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) {
|
||||
|
|
@ -77,7 +97,7 @@ func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration
|
|||
for name, data := range m.info.ProxyStatistics {
|
||||
if !data.LastCloseTime.IsZero() &&
|
||||
data.LastStartTime.Before(data.LastCloseTime) &&
|
||||
time.Since(data.LastCloseTime) > continuousOfflineDuration {
|
||||
m.clock.Since(data.LastCloseTime) > continuousOfflineDuration {
|
||||
delete(m.info.ProxyStatistics, name)
|
||||
count++
|
||||
log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
||||
|
|
@ -121,7 +141,7 @@ func (m *serverMetrics) NewProxy(name string, proxyType string, user string, cli
|
|||
}
|
||||
proxyStats.User = user
|
||||
proxyStats.ClientID = clientID
|
||||
proxyStats.LastStartTime = time.Now()
|
||||
proxyStats.LastStartTime = m.clock.Now()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
|
|
@ -131,7 +151,7 @@ func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
|||
counter.Dec(1)
|
||||
}
|
||||
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
|
||||
proxyStats.LastCloseTime = time.Now()
|
||||
proxyStats.LastCloseTime = m.clock.Now()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
83
pkg/metrics/mem/server_test.go
Normal file
83
pkg/metrics/mem/server_test.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
package mem
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestServerMetricsUsesClockForProxyTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
|
||||
metrics.NewProxy("proxy", "tcp", "user", "client-id")
|
||||
require.Equal(start, metrics.info.ProxyStatistics["proxy"].LastStartTime)
|
||||
|
||||
closedAt := start.Add(time.Minute)
|
||||
clk.SetTime(closedAt)
|
||||
metrics.CloseProxy("proxy", "tcp")
|
||||
require.Equal(closedAt, metrics.info.ProxyStatistics["proxy"].LastCloseTime)
|
||||
}
|
||||
|
||||
func TestServerMetricsClearUselessInfoUsesClock(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start.Add(25 * time.Hour))
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||
Name: "proxy",
|
||||
LastStartTime: start.Add(-time.Hour),
|
||||
LastCloseTime: start,
|
||||
}
|
||||
|
||||
count, total := metrics.clearUselessInfo(24 * time.Hour)
|
||||
|
||||
require.Equal(1, count)
|
||||
require.Equal(1, total)
|
||||
require.Empty(metrics.info.ProxyStatistics)
|
||||
}
|
||||
|
||||
func TestServerMetricsRunUsesClockTicker(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||
Name: "proxy",
|
||||
LastStartTime: start.Add(-time.Hour),
|
||||
LastCloseTime: start,
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
metrics.runUntil(stopCh)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
close(stopCh)
|
||||
<-done
|
||||
})
|
||||
|
||||
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||
clk.Step(8 * 24 * time.Hour)
|
||||
|
||||
require.Eventually(func() bool {
|
||||
return !metrics.hasProxyStatistics("proxy")
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) hasProxyStatistics(name string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.info.ProxyStatistics[name]
|
||||
return ok
|
||||
}
|
||||
|
|
@ -109,10 +109,9 @@ func AsyncHandler(f func(Message)) func(Message) {
|
|||
type Dispatcher struct {
|
||||
rw ReadWriter
|
||||
|
||||
sendCh chan Message
|
||||
doneCh chan struct{}
|
||||
msgHandlers map[reflect.Type]func(Message)
|
||||
defaultHandler func(Message)
|
||||
sendCh chan Message
|
||||
doneCh chan struct{}
|
||||
msgHandlers map[reflect.Type]func(Message)
|
||||
}
|
||||
|
||||
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
||||
|
|
@ -151,8 +150,6 @@ func (d *Dispatcher) readLoop() {
|
|||
|
||||
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
|
||||
handler(m)
|
||||
} else if d.defaultHandler != nil {
|
||||
d.defaultHandler(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -170,10 +167,6 @@ func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
|
|||
d.msgHandlers[reflect.TypeOf(msg)] = handler
|
||||
}
|
||||
|
||||
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
|
||||
d.defaultHandler = handler
|
||||
}
|
||||
|
||||
func (d *Dispatcher) Done() chan struct{} {
|
||||
return d.doneCh
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -144,19 +145,19 @@ func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, Recomman
|
|||
return behaviors[index].A, behaviors[index].B
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*behaviorScore {
|
||||
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*behaviorScore {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
||||
scores := make([]*behaviorScore, 0, len(behaviors))
|
||||
for i := range behaviors {
|
||||
score := receiverScore
|
||||
if behaviors[i].A.Role == DetectRoleSender {
|
||||
score = senderScore
|
||||
}
|
||||
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
|
||||
scores = append(scores, &behaviorScore{Mode: mode, Index: i, Score: score})
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
|
@ -170,14 +171,18 @@ type RecommandBehavior struct {
|
|||
ListenRandomPorts int
|
||||
}
|
||||
|
||||
type MakeHoleRecords struct {
|
||||
type makeHoleRecords struct {
|
||||
mu sync.Mutex
|
||||
scores []*BehaviorScore
|
||||
LastUpdateTime time.Time
|
||||
scores []*behaviorScore
|
||||
clock clock.PassiveClock
|
||||
lastUpdateTime time.Time
|
||||
}
|
||||
|
||||
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
||||
scores := []*BehaviorScore{}
|
||||
func newMakeHoleRecordsWithClock(c, v *NatFeature, clk clock.PassiveClock) *makeHoleRecords {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
scores := []*behaviorScore{}
|
||||
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
|
||||
appendMode0 := func() {
|
||||
switch {
|
||||
|
|
@ -212,13 +217,17 @@ func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
|||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
|
||||
}
|
||||
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
|
||||
return &makeHoleRecords{
|
||||
scores: scores,
|
||||
clock: clk,
|
||||
lastUpdateTime: clk.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
||||
func (mhr *makeHoleRecords) reportSuccess(mode int, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
mhr.lastUpdateTime = mhr.clock.Now()
|
||||
for i := range mhr.scores {
|
||||
score := mhr.scores[i]
|
||||
if score.Mode != mode || score.Index != index {
|
||||
|
|
@ -231,22 +240,22 @@ func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
|||
}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
|
||||
func (mhr *makeHoleRecords) recommand() (mode, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
|
||||
if len(mhr.scores) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int {
|
||||
maxScore := slices.MaxFunc(mhr.scores, func(a, b *behaviorScore) int {
|
||||
return cmp.Compare(a.Score, b.Score)
|
||||
})
|
||||
maxScore.Score--
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
mhr.lastUpdateTime = mhr.clock.Now()
|
||||
return maxScore.Mode, maxScore.Index
|
||||
}
|
||||
|
||||
type BehaviorScore struct {
|
||||
type behaviorScore struct {
|
||||
Mode int
|
||||
Index int
|
||||
// between -10 and 10
|
||||
|
|
@ -255,16 +264,25 @@ type BehaviorScore struct {
|
|||
|
||||
type Analyzer struct {
|
||||
// key is client ip + visitor ip
|
||||
records map[string]*MakeHoleRecords
|
||||
records map[string]*makeHoleRecords
|
||||
dataReserveDuration time.Duration
|
||||
clock clock.PassiveClock
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
|
||||
return newAnalyzerWithClock(dataReserveDuration, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newAnalyzerWithClock(dataReserveDuration time.Duration, clk clock.PassiveClock) *Analyzer {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &Analyzer{
|
||||
records: make(map[string]*MakeHoleRecords),
|
||||
records: make(map[string]*makeHoleRecords),
|
||||
dataReserveDuration: dataReserveDuration,
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -272,12 +290,12 @@ func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, in
|
|||
a.mu.Lock()
|
||||
records, ok := a.records[key]
|
||||
if !ok {
|
||||
records = NewMakeHoleRecords(c, v)
|
||||
records = newMakeHoleRecordsWithClock(c, v, a.clock)
|
||||
a.records[key] = records
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
mode, index = records.Recommand()
|
||||
mode, index = records.recommand()
|
||||
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
|
||||
|
||||
switch mode {
|
||||
|
|
@ -307,11 +325,11 @@ func (a *Analyzer) ReportSuccess(key string, mode, index int) {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
records.ReportSuccess(mode, index)
|
||||
records.reportSuccess(mode, index)
|
||||
}
|
||||
|
||||
func (a *Analyzer) Clean() (int, int) {
|
||||
now := time.Now()
|
||||
now := a.clock.Now()
|
||||
total := 0
|
||||
count := 0
|
||||
|
||||
|
|
@ -321,7 +339,7 @@ func (a *Analyzer) Clean() (int, int) {
|
|||
total = len(a.records)
|
||||
// clean up records that have not been used for a period of time.
|
||||
for key, records := range a.records {
|
||||
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
|
||||
if now.Sub(records.lastUpdateTime) > a.dataReserveDuration {
|
||||
delete(a.records, key)
|
||||
count++
|
||||
}
|
||||
|
|
|
|||
33
pkg/nathole/analysis_test.go
Normal file
33
pkg/nathole/analysis_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package nathole
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestAnalyzerUsesClockForRecordTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
analyzer := newAnalyzerWithClock(time.Hour, clk)
|
||||
clientFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||
visitorFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||
|
||||
mode, index, _, _ := analyzer.GetRecommandBehaviors("key", clientFeature, visitorFeature)
|
||||
require.Equal(start, analyzer.records["key"].lastUpdateTime)
|
||||
|
||||
updatedAt := start.Add(time.Minute)
|
||||
clk.SetTime(updatedAt)
|
||||
analyzer.ReportSuccess("key", mode, index)
|
||||
require.Equal(updatedAt, analyzer.records["key"].lastUpdateTime)
|
||||
|
||||
clk.SetTime(start.Add(2 * time.Hour))
|
||||
count, total := analyzer.Clean()
|
||||
require.Equal(1, count)
|
||||
require.Equal(1, total)
|
||||
require.Empty(analyzer.records)
|
||||
}
|
||||
|
|
@ -326,40 +326,16 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
|||
}
|
||||
|
||||
protocol := vm.Protocol
|
||||
vResp := &msg.NatHoleResp{
|
||||
TransactionID: vm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: slices.Compact(cm.MappedAddrs),
|
||||
AssistedAddrs: slices.Compact(cm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: vBehavior.Role,
|
||||
TTL: vBehavior.TTL,
|
||||
SendDelayMs: vBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
|
||||
SendRandomPorts: vBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: vBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
cResp := &msg.NatHoleResp{
|
||||
TransactionID: cm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: slices.Compact(vm.MappedAddrs),
|
||||
AssistedAddrs: slices.Compact(vm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: cBehavior.Role,
|
||||
TTL: cBehavior.TTL,
|
||||
SendDelayMs: cBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
|
||||
SendRandomPorts: cBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: cBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
vResp := newNatHoleResponse(
|
||||
vm.TransactionID, session.sid, protocol, mode,
|
||||
cm.MappedAddrs, cm.AssistedAddrs, vBehavior,
|
||||
timeoutMs-vBehavior.SendDelayMs, cNatFeature.PortsDifference,
|
||||
)
|
||||
cResp := newNatHoleResponse(
|
||||
cm.TransactionID, session.sid, protocol, mode,
|
||||
vm.MappedAddrs, vm.AssistedAddrs, cBehavior,
|
||||
timeoutMs-cBehavior.SendDelayMs, vNatFeature.PortsDifference,
|
||||
)
|
||||
|
||||
log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
|
||||
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
|
||||
|
|
@ -368,6 +344,38 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
|||
return vResp, cResp, nil
|
||||
}
|
||||
|
||||
func newNatHoleResponse(
|
||||
transactionID string,
|
||||
sid string,
|
||||
protocol string,
|
||||
mode int,
|
||||
candidateAddrs []string,
|
||||
assistedAddrs []string,
|
||||
behavior RecommandBehavior,
|
||||
readTimeoutMs int,
|
||||
portsDifference int,
|
||||
) *msg.NatHoleResp {
|
||||
compactCandidateAddrs := slices.Compact(candidateAddrs)
|
||||
compactAssistedAddrs := slices.Compact(assistedAddrs)
|
||||
return &msg.NatHoleResp{
|
||||
TransactionID: transactionID,
|
||||
Sid: sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: compactCandidateAddrs,
|
||||
AssistedAddrs: compactAssistedAddrs,
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: behavior.Role,
|
||||
TTL: behavior.TTL,
|
||||
SendDelayMs: behavior.SendDelayMs,
|
||||
ReadTimeoutMs: readTimeoutMs,
|
||||
SendRandomPorts: behavior.PortsRandomNumber,
|
||||
ListenRandomPorts: behavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(candidateAddrs, portsDifference, behavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
||||
if maxNumber <= 0 {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -17,17 +17,9 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -37,57 +29,28 @@ func init() {
|
|||
type HTTP2HTTPPlugin struct {
|
||||
opts *v1.HTTP2HTTPPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTP2HTTPPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTP2HTTPPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
req := r.Out
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
nil,
|
||||
)
|
||||
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Name() string {
|
||||
return v1.PluginHTTP2HTTP
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,18 +17,11 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -38,65 +31,35 @@ func init() {
|
|||
type HTTP2HTTPSPlugin struct {
|
||||
opts *v1.HTTP2HTTPSPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTP2HTTPSPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTP2HTTPSPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
|
||||
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
|
||||
req := r.Out
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
Transport: tr,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
tr,
|
||||
)
|
||||
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Name() string {
|
||||
return v1.PluginHTTP2HTTPS
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
|
|
|||
126
pkg/plugin/client/http_common.go
Normal file
126
pkg/plugin/client/http_common.go
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !frps
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/plugin/client/internal/httpsserver"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
const httpBridgeReadHeaderTimeout = 60 * time.Second
|
||||
|
||||
func rewriteHTTPPluginRequest(
|
||||
req *http.Request,
|
||||
scheme string,
|
||||
localAddr string,
|
||||
hostHeaderRewrite string,
|
||||
requestHeaders v1.HeaderOperations,
|
||||
) {
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = localAddr
|
||||
if hostHeaderRewrite != "" {
|
||||
req.Host = hostHeaderRewrite
|
||||
}
|
||||
for k, v := range requestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
type httpBridgePlugin struct {
|
||||
l *Listener
|
||||
s *http.Server
|
||||
|
||||
useSourceRemoteAddr bool
|
||||
}
|
||||
|
||||
func newHTTPBridgePluginServer(handler http.Handler, useSourceRemoteAddr bool) *httpBridgePlugin {
|
||||
listener := NewProxyListener()
|
||||
p := &httpBridgePlugin{
|
||||
l: listener,
|
||||
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||
}
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: httpBridgeReadHeaderTimeout,
|
||||
}
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
return p
|
||||
}
|
||||
|
||||
func newHTTPSBridgePluginServer(
|
||||
handler http.Handler,
|
||||
crtPath string,
|
||||
keyPath string,
|
||||
enableHTTP2 *bool,
|
||||
useSourceRemoteAddr bool,
|
||||
) (*httpBridgePlugin, error) {
|
||||
listener := NewProxyListener()
|
||||
server, err := httpsserver.New(handler, crtPath, keyPath, enableHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &httpBridgePlugin{
|
||||
l: listener,
|
||||
s: server,
|
||||
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||
}
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newHTTPBridgeReverseProxy(
|
||||
rewrite func(*httputil.ProxyRequest),
|
||||
transport http.RoundTripper,
|
||||
) *httputil.ReverseProxy {
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: rewrite,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
if transport != nil {
|
||||
rp.Transport = transport
|
||||
}
|
||||
return rp
|
||||
}
|
||||
|
||||
func (p *httpBridgePlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if p.useSourceRemoteAddr && connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *httpBridgePlugin) Close() error {
|
||||
err := p.s.Close()
|
||||
_ = p.l.Close()
|
||||
return err
|
||||
}
|
||||
|
|
@ -17,22 +17,9 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -42,80 +29,35 @@ func init() {
|
|||
type HTTPS2HTTPPlugin struct {
|
||||
opts *v1.HTTPS2HTTPPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTPS2HTTPPluginOptions)
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.SetXForwarded()
|
||||
req := r.Out
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
})
|
||||
nil,
|
||||
)
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
||||
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
p.httpBridgePlugin = server
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Name() string {
|
||||
return v1.PluginHTTPS2HTTP
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,22 +17,11 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -42,86 +31,39 @@ func init() {
|
|||
type HTTPS2HTTPSPlugin struct {
|
||||
opts *v1.HTTPS2HTTPSPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTPS2HTTPSPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPSPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.SetXForwarded()
|
||||
req := r.Out
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
Transport: tr,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
})
|
||||
tr,
|
||||
)
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
||||
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
p.httpBridgePlugin = server
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Name() string {
|
||||
return v1.PluginHTTPS2HTTPS
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
|
|
|||
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !frps
|
||||
|
||||
package httpsserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
)
|
||||
|
||||
func New(handler http.Handler, crtPath, keyPath string, enableHTTP2 *bool) (*http.Server, error) {
|
||||
tlsConfig, err := transport.NewServerTLSConfig(crtPath, keyPath, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: withMisdirectedRequestCheck(handler),
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(enableHTTP2) {
|
||||
server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func withMisdirectedRequestCheck(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
@ -44,6 +44,67 @@ func NewManager() *Manager {
|
|||
}
|
||||
}
|
||||
|
||||
func newPluginRequestContext() (context.Context, *xlog.Logger) {
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
return NewReqidContext(ctx, reqid), xl
|
||||
}
|
||||
|
||||
type pluginErrorLogMode bool
|
||||
|
||||
const (
|
||||
// Warn is the zero value because it is the default for mutable plugin operations.
|
||||
pluginErrorLogWarn pluginErrorLogMode = false
|
||||
pluginErrorLogInfo pluginErrorLogMode = true
|
||||
)
|
||||
|
||||
func logPluginError(xl *xlog.Logger, p Plugin, op string, err error, mode pluginErrorLogMode) {
|
||||
if mode == pluginErrorLogInfo {
|
||||
xl.Infof("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||
return
|
||||
}
|
||||
xl.Warnf("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||
}
|
||||
|
||||
func handleMutableContent[T any](
|
||||
plugins []Plugin,
|
||||
op string,
|
||||
content *T,
|
||||
logMode pluginErrorLogMode,
|
||||
) (*T, error) {
|
||||
if len(plugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
ctx, xl := newPluginRequestContext()
|
||||
|
||||
for _, p := range plugins {
|
||||
res, retContent, err = p.Handle(ctx, op, *content)
|
||||
if err != nil {
|
||||
logPluginError(xl, p, op, err, logMode)
|
||||
return nil, errors.New("send " + op + " request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
// Preserve the existing Plugin contract: changed content must be *T.
|
||||
// Buggy Plugin implementations still panic here, by design.
|
||||
content = retContent.(*T)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Register(p Plugin) {
|
||||
if p.IsSupport(OpLogin) {
|
||||
m.loginPlugins = append(m.loginPlugins, p)
|
||||
|
|
@ -66,71 +127,11 @@ func (m *Manager) Register(p Plugin) {
|
|||
}
|
||||
|
||||
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
|
||||
if len(m.loginPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.loginPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpLogin, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send Login request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Login request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*LoginContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.loginPlugins, OpLogin, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
|
||||
if len(m.newProxyPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newProxyPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewProxy, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send NewProxy request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewProxy request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewProxyContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.newProxyPlugins, OpNewProxy, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||
|
|
@ -139,10 +140,7 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
|||
}
|
||||
|
||||
errs := make([]string, 0)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
ctx, xl := newPluginRequestContext()
|
||||
|
||||
for _, p := range m.closeProxyPlugins {
|
||||
_, _, err := p.Handle(ctx, OpCloseProxy, *content)
|
||||
|
|
@ -159,103 +157,14 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
|||
}
|
||||
|
||||
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
|
||||
if len(m.pingPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.pingPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpPing, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send Ping request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Ping request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*PingContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.pingPlugins, OpPing, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
|
||||
if len(m.newWorkConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newWorkConnPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewWorkConn, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewWorkConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewWorkConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.newWorkConnPlugins, OpNewWorkConn, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
|
||||
if len(m.newUserConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newUserConnPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewUserConn, *content)
|
||||
if err != nil {
|
||||
xl.Infof("send NewUserConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewUserConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewUserConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
// Preserve the pre-refactor log level for NewUserConn plugin errors.
|
||||
return handleMutableContent(m.newUserConnPlugins, OpNewUserConn, content, pluginErrorLogInfo)
|
||||
}
|
||||
|
|
|
|||
336
pkg/plugin/server/manager_test.go
Normal file
336
pkg/plugin/server/manager_test.go
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
goliblog "github.com/fatedier/golib/log"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
frplog "github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
type testPlugin struct {
|
||||
name string
|
||||
ops map[string]bool
|
||||
handler func(context.Context, string, any) (*Response, any, error)
|
||||
}
|
||||
|
||||
// Log-capturing subtests serialize global logger swaps; do not use t.Parallel.
|
||||
var logCaptureMu sync.Mutex
|
||||
|
||||
type logCapture struct {
|
||||
bytes.Buffer
|
||||
levels []goliblog.Level
|
||||
}
|
||||
|
||||
func (p testPlugin) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p testPlugin) IsSupport(op string) bool {
|
||||
return p.ops[op]
|
||||
}
|
||||
|
||||
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
return p.handler(ctx, op, content)
|
||||
}
|
||||
|
||||
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
|
||||
w.levels = append(w.levels, level)
|
||||
return w.Write(p)
|
||||
}
|
||||
|
||||
func captureLogOutput(t *testing.T) *logCapture {
|
||||
t.Helper()
|
||||
|
||||
logCaptureMu.Lock()
|
||||
logOutput := &logCapture{}
|
||||
oldLogger := frplog.Logger
|
||||
frplog.Logger = goliblog.New(
|
||||
goliblog.WithOutput(logOutput),
|
||||
goliblog.WithLevel(goliblog.TraceLevel),
|
||||
goliblog.WithCaller(false),
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
frplog.Logger = oldLogger
|
||||
logCaptureMu.Unlock()
|
||||
})
|
||||
return logOutput
|
||||
}
|
||||
|
||||
var mutablePluginOps = []struct {
|
||||
name string
|
||||
op string
|
||||
}{
|
||||
{name: "login", op: OpLogin},
|
||||
{name: "new proxy", op: OpNewProxy},
|
||||
{name: "ping", op: OpPing},
|
||||
{name: "new work conn", op: OpNewWorkConn},
|
||||
{name: "new user conn", op: OpNewUserConn},
|
||||
}
|
||||
|
||||
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
|
||||
switch op {
|
||||
case OpLogin:
|
||||
got, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User, err
|
||||
case OpNewProxy:
|
||||
got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpPing:
|
||||
got, err := m.Ping(&PingContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpNewWorkConn:
|
||||
got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpNewUserConn:
|
||||
got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
default:
|
||||
panic("unsupported mutable op: " + op)
|
||||
}
|
||||
}
|
||||
|
||||
func mutableUser(t *testing.T, op string, content any) string {
|
||||
t.Helper()
|
||||
|
||||
switch op {
|
||||
case OpLogin:
|
||||
return content.(LoginContent).User
|
||||
case OpNewProxy:
|
||||
return content.(NewProxyContent).User.User
|
||||
case OpPing:
|
||||
return content.(PingContent).User.User
|
||||
case OpNewWorkConn:
|
||||
return content.(NewWorkConnContent).User.User
|
||||
case OpNewUserConn:
|
||||
return content.(NewUserConnContent).User.User
|
||||
default:
|
||||
t.Fatalf("unsupported mutable op: %s", op)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
|
||||
t.Helper()
|
||||
|
||||
switch op {
|
||||
case OpLogin:
|
||||
got := content.(LoginContent)
|
||||
got.User = user
|
||||
return &got
|
||||
case OpNewProxy:
|
||||
got := content.(NewProxyContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpPing:
|
||||
got := content.(PingContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpNewWorkConn:
|
||||
got := content.(NewWorkConnContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpNewUserConn:
|
||||
got := content.(NewUserConnContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
default:
|
||||
t.Fatalf("unsupported mutable op: %s", op)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentAcrossOps(t *testing.T) {
|
||||
for _, tt := range mutablePluginOps {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.Register(testPlugin{
|
||||
name: "mutate",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if op != tt.op {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if got := mutableUser(t, tt.op, content); got != "initial" {
|
||||
t.Fatalf("expected initial user, got %q", got)
|
||||
}
|
||||
return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
|
||||
},
|
||||
})
|
||||
m.Register(testPlugin{
|
||||
name: "observe",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if op != tt.op {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if got := mutableUser(t, tt.op, content); got != "mutated" {
|
||||
t.Fatalf("expected mutated user, got %q", got)
|
||||
}
|
||||
return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
|
||||
},
|
||||
})
|
||||
|
||||
got, err := callMutableWithUser(m, tt.op, "initial")
|
||||
if err != nil {
|
||||
t.Fatalf("mutable op failed: %v", err)
|
||||
}
|
||||
if got != "mutated" {
|
||||
t.Fatalf("expected mutated user, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
|
||||
m := NewManager()
|
||||
|
||||
var called bool
|
||||
m.Register(testPlugin{
|
||||
name: "reject",
|
||||
ops: map[string]bool{OpPing: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
|
||||
},
|
||||
})
|
||||
m.Register(testPlugin{
|
||||
name: "unused",
|
||||
ops: map[string]bool{OpPing: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
called = true
|
||||
return &Response{Unchange: true}, nil, nil
|
||||
},
|
||||
})
|
||||
|
||||
got, err := m.Ping(&PingContent{})
|
||||
if err == nil {
|
||||
t.Fatal("expected reject error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected no returned content, got %#v", got)
|
||||
}
|
||||
if err.Error() != "blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if called {
|
||||
t.Fatal("expected plugin chain to stop after reject")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
op string
|
||||
level goliblog.Level
|
||||
}{
|
||||
{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
|
||||
{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOutput := captureLogOutput(t)
|
||||
m := NewManager()
|
||||
m.Register(testPlugin{
|
||||
name: "error",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
return nil, nil, errors.New("boom")
|
||||
},
|
||||
})
|
||||
|
||||
_, err := callMutableWithUser(m, tt.op, "initial")
|
||||
if err == nil {
|
||||
t.Fatal("expected plugin error")
|
||||
}
|
||||
if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
|
||||
t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
|
||||
logOutput := captureLogOutput(t)
|
||||
m := NewManager()
|
||||
|
||||
for _, name := range []string{"first", "second"} {
|
||||
m.Register(testPlugin{
|
||||
name: name,
|
||||
ops: map[string]bool{OpCloseProxy: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if op != OpCloseProxy {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
return nil, nil, errors.New(name + " error")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
err := m.CloseProxy(&CloseProxyContent{})
|
||||
if err == nil {
|
||||
t.Fatal("expected close proxy error")
|
||||
}
|
||||
if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") {
|
||||
t.Fatalf("unexpected close proxy error prefix: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") {
|
||||
t.Fatalf("missing aggregated errors: %v", err)
|
||||
}
|
||||
if len(logOutput.levels) != 2 {
|
||||
t.Fatalf("expected two warning logs, got %v", logOutput.levels)
|
||||
}
|
||||
for _, level := range logOutput.levels {
|
||||
if level != goliblog.WarnLevel {
|
||||
t.Fatalf("expected warning log level, got %v", logOutput.levels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,8 @@ package metric
|
|||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
type DateCounter interface {
|
||||
|
|
@ -38,27 +40,33 @@ func NewDateCounter(reserveDays int64) DateCounter {
|
|||
type StandardDateCounter struct {
|
||||
reserveDays int64
|
||||
counts []int64
|
||||
clock clock.PassiveClock
|
||||
|
||||
lastUpdateDate time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
|
||||
now := time.Now()
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
s := &StandardDateCounter{
|
||||
return newStandardDateCounterWithClock(reserveDays, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newStandardDateCounterWithClock(reserveDays int64, clk clock.PassiveClock) *StandardDateCounter {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &StandardDateCounter{
|
||||
reserveDays: reserveDays,
|
||||
counts: make([]int64, reserveDays),
|
||||
lastUpdateDate: now,
|
||||
clock: clk,
|
||||
lastUpdateDate: startOfDay(clk.Now()),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) TodayCount() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
return c.counts[0]
|
||||
}
|
||||
|
||||
|
|
@ -70,65 +78,61 @@ func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 {
|
|||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
for i := 0; i < int(lastdays); i++ {
|
||||
counts[i] = c.counts[i]
|
||||
}
|
||||
c.rotate(c.clock.Now())
|
||||
copy(counts, c.counts)
|
||||
return counts
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Inc(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
c.counts[0] += count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Dec(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
c.counts[0] -= count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Snapshot() DateCounter {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tmp := newStandardDateCounter(c.reserveDays)
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
tmp.counts[i] = c.counts[i]
|
||||
}
|
||||
tmp := newStandardDateCounterWithClock(c.reserveDays, c.clock)
|
||||
copy(tmp.counts, c.counts)
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
c.counts[i] = 0
|
||||
}
|
||||
clear(c.counts)
|
||||
}
|
||||
|
||||
// rotate
|
||||
// Must hold the lock before calling this function.
|
||||
func (c *StandardDateCounter) rotate(now time.Time) {
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
now = startOfDay(now)
|
||||
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
|
||||
|
||||
defer func() {
|
||||
c.lastUpdateDate = now
|
||||
}()
|
||||
reserveDays := int(c.reserveDays)
|
||||
|
||||
if days <= 0 {
|
||||
return
|
||||
} else if days >= int(c.reserveDays) {
|
||||
} else if days >= reserveDays {
|
||||
c.counts = make([]int64, c.reserveDays)
|
||||
c.lastUpdateDate = now
|
||||
return
|
||||
}
|
||||
newCounts := make([]int64, c.reserveDays)
|
||||
|
||||
for i := days; i < int(c.reserveDays); i++ {
|
||||
newCounts[i] = c.counts[i-days]
|
||||
}
|
||||
copy(newCounts[days:], c.counts[:reserveDays-days])
|
||||
c.counts = newCounts
|
||||
c.lastUpdateDate = now
|
||||
}
|
||||
|
||||
// startOfDay returns midnight in t's location.
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
package metric
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestDateCounter(t *testing.T) {
|
||||
|
|
@ -25,3 +28,107 @@ func TestDateCounter(t *testing.T) {
|
|||
dcTmp := dc.Snapshot()
|
||||
require.EqualValues(5, dcTmp.TodayCount())
|
||||
}
|
||||
|
||||
func TestDateCounterRotate(t *testing.T) {
|
||||
loc := time.FixedZone("test", 8*60*60)
|
||||
lastUpdateDate := time.Date(2026, time.May, 8, 0, 0, 0, 0, loc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
want []int64
|
||||
wantLastUpdateDate time.Time
|
||||
}{
|
||||
{
|
||||
name: "same day",
|
||||
now: time.Date(2026, time.May, 8, 12, 30, 0, 0, loc),
|
||||
want: []int64{10, 7, 3},
|
||||
wantLastUpdateDate: lastUpdateDate,
|
||||
},
|
||||
{
|
||||
name: "clock skew",
|
||||
now: time.Date(2026, time.May, 7, 12, 30, 0, 0, loc),
|
||||
want: []int64{10, 7, 3},
|
||||
wantLastUpdateDate: lastUpdateDate,
|
||||
},
|
||||
{
|
||||
name: "one day",
|
||||
now: time.Date(2026, time.May, 9, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 10, 7},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 9, 0, 0, 0, 0, loc),
|
||||
},
|
||||
{
|
||||
name: "two days",
|
||||
now: time.Date(2026, time.May, 10, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 0, 10},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 10, 0, 0, 0, 0, loc),
|
||||
},
|
||||
{
|
||||
name: "all reserved days elapsed",
|
||||
now: time.Date(2026, time.May, 11, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 0, 0},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 11, 0, 0, 0, 0, loc),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
dc := newStandardDateCounter(3)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
dc.lastUpdateDate = lastUpdateDate
|
||||
|
||||
dc.mu.Lock()
|
||||
dc.rotate(tt.now)
|
||||
dc.mu.Unlock()
|
||||
|
||||
require.Equal(tt.want, dc.counts)
|
||||
require.Equal(tt.wantLastUpdateDate, dc.lastUpdateDate)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateCounterGetLastDaysCountReturnsCopy(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||
dc := newStandardDateCounterWithClock(3, clk)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
|
||||
counts := dc.GetLastDaysCount(2)
|
||||
require.Equal([]int64{10, 7}, counts)
|
||||
|
||||
counts[0] = 100
|
||||
require.Equal([]int64{10, 7}, dc.GetLastDaysCount(2))
|
||||
}
|
||||
|
||||
func TestDateCounterClear(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
dc := newStandardDateCounter(3)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
|
||||
dc.Clear()
|
||||
|
||||
require.Equal([]int64{0, 0, 0}, dc.counts)
|
||||
}
|
||||
|
||||
func TestDateCounterConcurrentAccess(t *testing.T) {
|
||||
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||
dc := newStandardDateCounterWithClock(3, clk)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 8 {
|
||||
wg.Go(func() {
|
||||
for range 100 {
|
||||
dc.Inc(1)
|
||||
dc.Dec(1)
|
||||
_ = dc.TodayCount()
|
||||
_ = dc.GetLastDaysCount(3)
|
||||
_ = dc.Snapshot()
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,11 +39,14 @@ type HTTPConnectTCPMuxer struct {
|
|||
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
|
||||
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
|
||||
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mux.SetCheckAuthFunc(ret.auth).
|
||||
SetSuccessHookFunc(ret.sendConnectResponse).
|
||||
SetFailHookFunc(vhostFailed)
|
||||
ret.Muxer = mux
|
||||
return ret, err
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ import (
|
|||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
|
|
@ -160,7 +159,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
|
|||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
|
||||
vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
|
||||
vr, ok := rp.vhostRouter.getByRoute(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
|
||||
return vr.payload.(*RouteConfig)
|
||||
|
|
@ -171,7 +170,7 @@ func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser str
|
|||
// CreateConnection create a new connection by route config
|
||||
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
|
||||
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
|
||||
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
||||
vr, ok := rp.vhostRouter.getByRoute(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
||||
if ok {
|
||||
if byEndpoint {
|
||||
fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn
|
||||
|
|
@ -208,50 +207,6 @@ func checkRouteAuthByRequest(req *http.Request, rc *RouteConfig) bool {
|
|||
return ok && user == rc.Username && passwd == rc.Password
|
||||
}
|
||||
|
||||
// getVhost tries to get vhost router by route policy.
|
||||
func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) {
|
||||
findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) {
|
||||
vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser)
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
||||
vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "")
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// First we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
vr, ok := findRouter(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
|
||||
// e.g. domain = test.example.com, try to match wildcard domains.
|
||||
// *.example.com
|
||||
// *.com
|
||||
domainSplit := strings.Split(domain, ".")
|
||||
for len(domainSplit) >= 3 {
|
||||
domainSplit[0] = "*"
|
||||
domain = strings.Join(domainSplit, ".")
|
||||
vr, ok = findRouter(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
|
||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
||||
vr, ok = findRouter("*", location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ type HTTPSMuxer struct {
|
|||
|
||||
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
|
||||
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
|
||||
mux.SetFailHookFunc(vhostFailed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &HTTPSMuxer{mux}, err
|
||||
mux.SetFailHookFunc(vhostFailed)
|
||||
return &HTTPSMuxer{mux}, nil
|
||||
}
|
||||
|
||||
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||
|
|
|
|||
|
|
@ -16,23 +16,53 @@ func TestGetHTTPSHostname(t *testing.T) {
|
|||
require.NoError(err)
|
||||
defer l.Close()
|
||||
|
||||
var conn net.Conn
|
||||
connCh := make(chan net.Conn, 1)
|
||||
acceptErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, _ = l.Accept()
|
||||
require.NotNil(conn)
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
acceptErrCh <- err
|
||||
return
|
||||
}
|
||||
connCh <- conn
|
||||
}()
|
||||
|
||||
clientErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||
conn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "example.com",
|
||||
})
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
clientErrCh <- err
|
||||
}()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
_, infos, err := GetHTTPSHostname(conn)
|
||||
var conn net.Conn
|
||||
select {
|
||||
case conn = <-connCh:
|
||||
case err := <-acceptErrCh:
|
||||
require.NoError(err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for accepted connection")
|
||||
}
|
||||
require.NotNil(conn)
|
||||
|
||||
serverConn, infos, err := GetHTTPSHostname(conn)
|
||||
if serverConn != nil {
|
||||
_ = serverConn.Close()
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal("example.com", infos["Host"])
|
||||
require.Equal("https", infos["Scheme"])
|
||||
|
||||
select {
|
||||
case <-clientErrCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TLS client")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,12 +93,19 @@ func (r *Routers) Del(domain, location, httpUser string) {
|
|||
routersByHTTPUser[httpUser] = newVrs
|
||||
}
|
||||
|
||||
// Get returns the best location match for an exact host and exact HTTP user.
|
||||
// It does not apply all-users, wildcard-domain, or catch-all-domain fallback.
|
||||
func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
return r.getLocked(host, path, httpUser)
|
||||
}
|
||||
|
||||
// getLocked performs an exact-host lookup; host must already be lower-cased.
|
||||
func (r *Routers) getLocked(host, path, httpUser string) (vr *Router, exist bool) {
|
||||
routersByHTTPUser, found := r.indexByDomain[host]
|
||||
if !found {
|
||||
return
|
||||
|
|
@ -117,6 +124,49 @@ func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
|||
return
|
||||
}
|
||||
|
||||
func (r *Routers) getByRoute(host, path, httpUser string) (*Router, bool) {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// First we check the full hostname; if it doesn't exist, then check wildcard domains.
|
||||
// For example, test.example.com checks *.example.com before falling back to "*".
|
||||
vr, ok := r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
|
||||
hostSplit := strings.Split(host, ".")
|
||||
// Keep two-label hosts out of the wildcard walk, so example.com does not match *.com.
|
||||
for len(hostSplit) >= 3 {
|
||||
// Replace the leftmost remaining label with the wildcard marker.
|
||||
hostSplit[0] = "*"
|
||||
host = strings.Join(hostSplit, ".")
|
||||
vr, ok = r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
hostSplit = hostSplit[1:]
|
||||
}
|
||||
|
||||
// Finally, try to check if there is one proxy whose domain is "*", which means match all domains.
|
||||
return r.getExactOrAllUsersLocked("*", path, httpUser)
|
||||
}
|
||||
|
||||
func (r *Routers) getExactOrAllUsersLocked(host, path, httpUser string) (*Router, bool) {
|
||||
vr, ok := r.getLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routeByHTTPUser, it means match all.
|
||||
vr, ok = r.getLocked(host, path, "")
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
|
||||
routersByHTTPUser, found := r.indexByDomain[host]
|
||||
if !found {
|
||||
|
|
|
|||
257
pkg/util/vhost/router_test.go
Normal file
257
pkg/util/vhost/router_test.go
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRoutersGet(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||
require.NoError(t, routers.Add("example.com", "/public", "", "exact-all-users"))
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*", "/api", "", "all-domains"))
|
||||
|
||||
t.Run("exact host and user match with normalized domain", func(t *testing.T) {
|
||||
router, ok := routers.Get("EXAMPLE.COM", "/api/users", "alice")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "exact-user", router.payload)
|
||||
})
|
||||
|
||||
t.Run("exact host and empty user match", func(t *testing.T) {
|
||||
router, ok := routers.Get("EXAMPLE.COM", "/public/docs", "")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "exact-all-users", router.payload)
|
||||
})
|
||||
|
||||
t.Run("does not fall back from named user to empty user", func(t *testing.T) {
|
||||
// Get intentionally requires an exact HTTP user; route-level fallbacks live in getByRoute.
|
||||
_, ok := routers.Get("EXAMPLE.COM", "/public/docs", "alice")
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("does not fall back to wildcard domains", func(t *testing.T) {
|
||||
_, ok := routers.Get("foo.example.com", "/api/users", "")
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("does not fall back to catch-all domain", func(t *testing.T) {
|
||||
_, ok := routers.Get("missing.test", "/api/users", "")
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutersGetByRoute(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||
require.NoError(t, routers.Add("example.com", "/api", "", "exact-all-users"))
|
||||
require.NoError(t, routers.Add("exact.example.com", "/api", "", "exact-subdomain"))
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*.foo.example.com", "/api", "", "specific-wildcard"))
|
||||
require.NoError(t, routers.Add("*.bar.com", "/api", "", "wildcard-parent-domain"))
|
||||
require.NoError(t, routers.Add("*", "/admin", "root", "all-domains-user"))
|
||||
require.NoError(t, routers.Add("*", "/", "", "all-domains"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
location string
|
||||
httpUser string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "exact domain and http user",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "alice",
|
||||
want: "exact-user",
|
||||
},
|
||||
{
|
||||
name: "exact domain falls back to all users",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "exact-all-users",
|
||||
},
|
||||
{
|
||||
name: "wildcard domain uses all users fallback",
|
||||
domain: "foo.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-all-users",
|
||||
},
|
||||
{
|
||||
name: "mixed-case domain is normalized",
|
||||
domain: "Foo.Example.Com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-all-users",
|
||||
},
|
||||
{
|
||||
name: "exact domain wins over wildcard domain",
|
||||
domain: "exact.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "exact-subdomain",
|
||||
},
|
||||
{
|
||||
name: "more specific wildcard wins over broader wildcard",
|
||||
domain: "bar.foo.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "specific-wildcard",
|
||||
},
|
||||
{
|
||||
name: "wildcard walk checks parent domains",
|
||||
domain: "a.b.bar.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-parent-domain",
|
||||
},
|
||||
{
|
||||
name: "catch-all domain fallback",
|
||||
domain: "foo.test.com",
|
||||
location: "/other",
|
||||
httpUser: "bob",
|
||||
want: "all-domains",
|
||||
},
|
||||
{
|
||||
name: "catch-all domain honors http user",
|
||||
domain: "foo.test.com",
|
||||
location: "/admin/panel",
|
||||
httpUser: "root",
|
||||
want: "all-domains-user",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router, ok := routers.getByRoute(tt.domain, tt.location, tt.httpUser)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.want, router.payload)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutersGetByRouteNoMatch(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*.com", "/api", "", "top-level-wildcard"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
location string
|
||||
}{
|
||||
{
|
||||
name: "two-label domain does not enter wildcard walk",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "missing catch-all remains no match",
|
||||
domain: "foo.test.com",
|
||||
location: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "wrong path remains no match",
|
||||
domain: "foo.example.com",
|
||||
location: "/other",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router, ok := routers.getByRoute(tt.domain, tt.location, "bob")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutersConcurrentGetByRouteAndAdd(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard"))
|
||||
|
||||
const readers = 8
|
||||
const iterations = 200
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
start := make(chan struct{})
|
||||
errCh := make(chan error, readers+1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for id := range readers {
|
||||
wg.Go(func() {
|
||||
<-start
|
||||
for j := range iterations {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
router, ok := routers.getByRoute("foo.example.com", "/api/users", "")
|
||||
if !ok || router == nil || router.payload != "wildcard" {
|
||||
errCh <- fmt.Errorf("reader %d iteration %d got router=%v ok=%v", id, j, router, ok)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
wg.Go(func() {
|
||||
<-start
|
||||
for i := range iterations {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := routers.Add(fmt.Sprintf("host-%d.example.com", i), "/api", "", i)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
close(start)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
cancel()
|
||||
t.Fatal("concurrent route lookup and add timed out")
|
||||
}
|
||||
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -148,41 +148,9 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
|
|||
}
|
||||
|
||||
func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
|
||||
findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) {
|
||||
vr, ok := v.registryRouter.Get(inName, inPath, inHTTPUser)
|
||||
if ok {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
||||
vr, ok = v.registryRouter.Get(inName, inPath, "")
|
||||
if ok {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// first we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
l, ok := findRouter(name, path, httpUser)
|
||||
vr, ok := v.registryRouter.getByRoute(name, path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
|
||||
domainSplit := strings.Split(name, ".")
|
||||
for len(domainSplit) >= 3 {
|
||||
domainSplit[0] = "*"
|
||||
name = strings.Join(domainSplit, ".")
|
||||
|
||||
l, ok = findRouter(name, path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
||||
l, ok = findRouter("*", path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
|
|
@ -48,6 +50,7 @@ type FastBackoffOptions struct {
|
|||
|
||||
type fastBackoffImpl struct {
|
||||
options FastBackoffOptions
|
||||
clock clock.PassiveClock
|
||||
|
||||
lastCalledTime time.Time
|
||||
consecutiveErrCount int
|
||||
|
|
@ -57,18 +60,26 @@ type fastBackoffImpl struct {
|
|||
}
|
||||
|
||||
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
|
||||
return newFastBackoffManagerWithClock(options, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newFastBackoffManagerWithClock(options FastBackoffOptions, clk clock.PassiveClock) BackoffManager {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &fastBackoffImpl{
|
||||
options: options,
|
||||
clock: clk,
|
||||
countsInFastRetryWindow: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
|
||||
if f.lastCalledTime.IsZero() {
|
||||
f.lastCalledTime = time.Now()
|
||||
f.lastCalledTime = f.clock.Now()
|
||||
return f.options.Duration
|
||||
}
|
||||
now := time.Now()
|
||||
now := f.clock.Now()
|
||||
f.lastCalledTime = now
|
||||
|
||||
if previousConditionError {
|
||||
|
|
|
|||
27
pkg/util/wait/backoff_test.go
Normal file
27
pkg/util/wait/backoff_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package wait
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestFastBackoffManagerUsesClock(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
backoff := newFastBackoffManagerWithClock(FastBackoffOptions{
|
||||
Duration: time.Second,
|
||||
}, clk).(*fastBackoffImpl)
|
||||
|
||||
require.Equal(time.Second, backoff.Backoff(0, false))
|
||||
require.Equal(start, backoff.lastCalledTime)
|
||||
|
||||
next := start.Add(time.Minute)
|
||||
clk.SetTime(next)
|
||||
require.Equal(time.Second, backoff.Backoff(time.Second, false))
|
||||
require.Equal(next, backoff.lastCalledTime)
|
||||
}
|
||||
|
|
@ -19,7 +19,6 @@ import "strings"
|
|||
// LogWriter forwards writes to frp's logger at configurable level.
|
||||
// It is safe for concurrent use as long as the underlying Logger is thread-safe.
|
||||
type LogWriter struct {
|
||||
xl *Logger
|
||||
logFunc func(string)
|
||||
}
|
||||
|
||||
|
|
@ -31,35 +30,30 @@ func (w LogWriter) Write(p []byte) (n int, err error) {
|
|||
|
||||
func NewTraceWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Tracef("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewDebugWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Debugf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewInfoWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Infof("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewWarnWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Warnf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Errorf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -286,7 +286,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri
|
|||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.routes[name] = &routeElement{
|
||||
name: name,
|
||||
routes: routes,
|
||||
conn: conn,
|
||||
}
|
||||
|
|
@ -383,7 +382,6 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
|
|||
}
|
||||
|
||||
type routeElement struct {
|
||||
name string
|
||||
routes []net.IPNet
|
||||
conn io.ReadWriteCloser
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
|
|
@ -38,16 +40,25 @@ type Manager struct {
|
|||
|
||||
bindAddr string
|
||||
netType string
|
||||
clock clock.WithTicker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
|
||||
return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
pm := &Manager{
|
||||
reservedPorts: make(map[string]*PortCtx),
|
||||
usedPorts: make(map[int]*PortCtx),
|
||||
freePorts: make(map[int]struct{}),
|
||||
bindAddr: bindAddr,
|
||||
netType: netType,
|
||||
clock: clk,
|
||||
}
|
||||
if len(allowPorts) > 0 {
|
||||
for _, pair := range allowPorts {
|
||||
|
|
@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
|||
portCtx := &PortCtx{
|
||||
ProxyName: name,
|
||||
Closed: false,
|
||||
UpdateTime: time.Now(),
|
||||
UpdateTime: pm.clock.Now(),
|
||||
}
|
||||
|
||||
var ok bool
|
||||
|
|
@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
|||
if ctx, ok := pm.reservedPorts[name]; ok {
|
||||
if pm.isPortAvailable(ctx.Port) {
|
||||
realPort = ctx.Port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
|||
}
|
||||
if pm.isPortAvailable(k) {
|
||||
realPort = k
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
|||
if _, ok = pm.freePorts[port]; ok {
|
||||
if pm.isPortAvailable(port) {
|
||||
realPort = port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
} else {
|
||||
err = ErrPortUnAvailable
|
||||
}
|
||||
|
|
@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// markPortAcquiredLocked records a successful acquisition. pm.mu must be held.
|
||||
func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) {
|
||||
pm.usedPorts[port] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, port)
|
||||
}
|
||||
|
||||
func (pm *Manager) isPortAvailable(port int) bool {
|
||||
if pm.netType == "udp" {
|
||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
|
||||
|
|
@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) {
|
|||
pm.freePorts[port] = struct{}{}
|
||||
delete(pm.usedPorts, port)
|
||||
ctx.Closed = true
|
||||
ctx.UpdateTime = time.Now()
|
||||
ctx.UpdateTime = pm.clock.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// Release reserved port if it isn't used in last 24 hours.
|
||||
func (pm *Manager) cleanReservedPortsWorker() {
|
||||
pm.cleanReservedPortsWorkerUntil(nil)
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) {
|
||||
ticker := pm.clock.NewTicker(CleanReservedPortsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
time.Sleep(CleanReservedPortsInterval)
|
||||
pm.mu.Lock()
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
select {
|
||||
case <-ticker.C():
|
||||
pm.cleanReservedPortsOnce()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsOnce() {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
72
server/ports/ports_test.go
Normal file
72
server/ports/ports_test.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package ports
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
func TestManagerUsesClockForPortTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
require.Equal(start, pm.usedPorts[port].UpdateTime)
|
||||
|
||||
releasedAt := start.Add(time.Minute)
|
||||
clk.SetTime(releasedAt)
|
||||
pm.Release(port)
|
||||
|
||||
require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime)
|
||||
}
|
||||
|
||||
func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
pm.Release(port)
|
||||
require.True(pm.hasReservedPort("proxy"))
|
||||
|
||||
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||
clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute)
|
||||
|
||||
require.Eventually(func() bool {
|
||||
return !pm.hasReservedPort("proxy")
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
func (pm *Manager) hasReservedPort(name string) bool {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
_, ok := pm.reservedPorts[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func freeTCPPort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// ClientInfo captures metadata about a connected frpc instance.
|
||||
|
|
@ -42,12 +44,21 @@ type ClientRegistry struct {
|
|||
mu sync.RWMutex
|
||||
clients map[string]*ClientInfo
|
||||
runIndex map[string]string
|
||||
clock clock.PassiveClock
|
||||
}
|
||||
|
||||
func NewClientRegistry() *ClientRegistry {
|
||||
return newClientRegistryWithClock(clock.RealClock{})
|
||||
}
|
||||
|
||||
func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &ClientRegistry{
|
||||
clients: make(map[string]*ClientInfo),
|
||||
runIndex: make(map[string]string),
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -64,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
|||
key = cr.composeClientKey(user, effectiveID)
|
||||
enforceUnique := rawClientID != ""
|
||||
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
|
|
@ -116,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) {
|
|||
} else {
|
||||
info.RunID = ""
|
||||
info.Online = false
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
info.DisconnectedAt = now
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ package registry
|
|||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
|
@ -35,3 +38,37 @@ func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
|
|||
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRegistryUsesClockForTimestamps(t *testing.T) {
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
registry := newClientRegistryWithClock(clk)
|
||||
|
||||
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||
if conflict {
|
||||
t.Fatal("unexpected client conflict")
|
||||
}
|
||||
|
||||
info, ok := registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found", key)
|
||||
}
|
||||
if !info.FirstConnectedAt.Equal(start) {
|
||||
t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt)
|
||||
}
|
||||
if !info.LastConnectedAt.Equal(start) {
|
||||
t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt)
|
||||
}
|
||||
|
||||
disconnectedAt := start.Add(time.Minute)
|
||||
clk.SetTime(disconnectedAt)
|
||||
registry.MarkOfflineByRunID("run-id")
|
||||
|
||||
info, ok = registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found after disconnect", key)
|
||||
}
|
||||
if !info.DisconnectedAt.Equal(disconnectedAt) {
|
||||
t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,62 +0,0 @@
|
|||
package framework
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CleanupActionHandle is an integer pointer type for handling cleanup action
|
||||
type CleanupActionHandle *int
|
||||
|
||||
type cleanupFuncHandle struct {
|
||||
actionHandle CleanupActionHandle
|
||||
actionHook func()
|
||||
}
|
||||
|
||||
var (
|
||||
cleanupActionsLock sync.Mutex
|
||||
cleanupHookList = []cleanupFuncHandle{}
|
||||
)
|
||||
|
||||
// AddCleanupAction installs a function that will be called in the event of the
|
||||
// whole test being terminated. This allows arbitrary pieces of the overall
|
||||
// test to hook into SynchronizedAfterSuite().
|
||||
// The hooks are called in last-in-first-out order.
|
||||
func AddCleanupAction(fn func()) CleanupActionHandle {
|
||||
p := CleanupActionHandle(new(int))
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
c := cleanupFuncHandle{actionHandle: p, actionHook: fn}
|
||||
cleanupHookList = append([]cleanupFuncHandle{c}, cleanupHookList...)
|
||||
return p
|
||||
}
|
||||
|
||||
// RemoveCleanupAction removes a function that was installed by
|
||||
// AddCleanupAction.
|
||||
func RemoveCleanupAction(p CleanupActionHandle) {
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
for i, item := range cleanupHookList {
|
||||
if item.actionHandle == p {
|
||||
cleanupHookList = append(cleanupHookList[:i], cleanupHookList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunCleanupActions runs all functions installed by AddCleanupAction. It does
|
||||
// not remove them (see RemoveCleanupAction) but it does run unlocked, so they
|
||||
// may remove themselves.
|
||||
func RunCleanupActions() {
|
||||
list := []func(){}
|
||||
func() {
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
for _, p := range cleanupHookList {
|
||||
list = append(list, p.actionHook)
|
||||
}
|
||||
}()
|
||||
// Run unlocked.
|
||||
for _, fn := range list {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
|
@ -23,11 +23,6 @@ func ExpectNotEqual(actual any, extra any, explain ...any) {
|
|||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...)
|
||||
}
|
||||
|
||||
// ExpectError expects an error happens, otherwise an exception raises
|
||||
func ExpectError(err error, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, err).To(gomega.HaveOccurred(), explain...)
|
||||
}
|
||||
|
||||
func ExpectErrorWithOffset(offset int, err error, explain ...any) {
|
||||
gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...)
|
||||
}
|
||||
|
|
@ -47,11 +42,6 @@ func ExpectContainSubstring(actual, substr string, explain ...any) {
|
|||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...)
|
||||
}
|
||||
|
||||
// ExpectConsistOf expects actual contains precisely the extra elements. The ordering of the elements does not matter.
|
||||
func ExpectConsistOf(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.ConsistOf(extra), explain...)
|
||||
}
|
||||
|
||||
func ExpectContainElements(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...)
|
||||
}
|
||||
|
|
@ -60,16 +50,6 @@ func ExpectNotContainElements(actual any, extra any, explain ...any) {
|
|||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...)
|
||||
}
|
||||
|
||||
// ExpectHaveKey expects the actual map has the key in the keyset
|
||||
func ExpectHaveKey(actual any, key any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.HaveKey(key), explain...)
|
||||
}
|
||||
|
||||
// ExpectEmpty expects actual is empty
|
||||
func ExpectEmpty(actual any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.BeEmpty(), explain...)
|
||||
}
|
||||
|
||||
func ExpectTrue(actual any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,11 +38,6 @@ type Framework struct {
|
|||
// Multiple default mock servers used for e2e testing.
|
||||
mockServers *MockServers
|
||||
|
||||
// To make sure that this framework cleans up after itself, no matter what,
|
||||
// we install a Cleanup action before each test and clear it after. If we
|
||||
// should abort, the AfterSuite hook should run all Cleanup actions.
|
||||
cleanupHandle CleanupActionHandle
|
||||
|
||||
// beforeEachStarted indicates that BeforeEach has started
|
||||
beforeEachStarted bool
|
||||
|
||||
|
|
@ -87,8 +82,6 @@ func NewFramework(opt Options) *Framework {
|
|||
func (f *Framework) BeforeEach() {
|
||||
f.beforeEachStarted = true
|
||||
|
||||
f.cleanupHandle = AddCleanupAction(f.AfterEach)
|
||||
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*")
|
||||
ExpectNoError(err)
|
||||
f.TempDirectory = dir
|
||||
|
|
@ -113,8 +106,6 @@ func (f *Framework) AfterEach() {
|
|||
return
|
||||
}
|
||||
|
||||
RemoveCleanupAction(f.cleanupHandle)
|
||||
|
||||
// stop processor
|
||||
for _, p := range f.serverProcesses {
|
||||
_ = p.Stop()
|
||||
|
|
@ -266,10 +257,6 @@ func (f *Framework) AllocPortExcludingRanges(ranges ...[2]int) int {
|
|||
return 0
|
||||
}
|
||||
|
||||
func (f *Framework) ReleasePort(port int) {
|
||||
f.portAllocator.Release(port)
|
||||
}
|
||||
|
||||
func (f *Framework) RunServer(portName string, s server.Server) {
|
||||
f.servers = append(f.servers, s)
|
||||
if s.BindPort() > 0 && portName != "" {
|
||||
|
|
|
|||
|
|
@ -75,11 +75,3 @@ func (m *MockServers) GetTemplateParams() map[string]any {
|
|||
ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *MockServers) GetParam(key string) any {
|
||||
params := m.GetTemplateParams()
|
||||
if v, ok := params[key]; ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,7 +124,3 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *RequestExpect) Do() (*request.Response, error) {
|
||||
return e.req.Do()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,13 +31,6 @@ func New(options ...Option) *Server {
|
|||
return s
|
||||
}
|
||||
|
||||
func WithBindAddr(addr string) Option {
|
||||
return func(s *Server) *Server {
|
||||
s.bindAddr = addr
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
func WithBindPort(port int) Option {
|
||||
return func(s *Server) *Server {
|
||||
s.bindPort = port
|
||||
|
|
|
|||
|
|
@ -61,21 +61,6 @@ func WithBindPort(port int) Option {
|
|||
return func(s *Server) { s.bindPort = port }
|
||||
}
|
||||
|
||||
func WithClientCredentials(id, secret string) Option {
|
||||
return func(s *Server) {
|
||||
s.clientID = id
|
||||
s.clientSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudience(aud string) Option {
|
||||
return func(s *Server) { s.audience = aud }
|
||||
}
|
||||
|
||||
func WithSubject(sub string) Option {
|
||||
return func(s *Server) { s.subject = sub }
|
||||
}
|
||||
|
||||
func WithExpiresIn(seconds int) Option {
|
||||
return func(s *Server) { s.expiresIn = seconds }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,13 +40,8 @@ type Process struct {
|
|||
closeOne sync.Once
|
||||
waitErr error
|
||||
|
||||
started bool
|
||||
beforeStopHandler func()
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func New(path string, params []string) *Process {
|
||||
return NewWithEnvs(path, params, nil)
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func NewWithEnvs(path string, params []string, envs []string) *Process {
|
||||
|
|
@ -100,9 +95,6 @@ func (p *Process) Stop() error {
|
|||
defer func() {
|
||||
p.stopped = true
|
||||
}()
|
||||
if p.beforeStopHandler != nil {
|
||||
p.beforeStopHandler()
|
||||
}
|
||||
p.cancel()
|
||||
<-p.done
|
||||
return p.waitErr
|
||||
|
|
@ -125,10 +117,6 @@ func (p *Process) CountOutput(pattern string) int {
|
|||
return strings.Count(p.Output(), pattern)
|
||||
}
|
||||
|
||||
func (p *Process) SetBeforeStopHandler(fn func()) {
|
||||
p.beforeStopHandler = fn
|
||||
}
|
||||
|
||||
// WaitForOutput polls the combined process output until the pattern is found
|
||||
// count time(s) or the timeout is reached. It also returns early if the process exits.
|
||||
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {
|
||||
|
|
|
|||
|
|
@ -22,11 +22,10 @@ type Request struct {
|
|||
protocol string
|
||||
|
||||
// for all protocol
|
||||
addr string
|
||||
port int
|
||||
body []byte
|
||||
timeout time.Duration
|
||||
resolver *net.Resolver
|
||||
addr string
|
||||
port int
|
||||
body []byte
|
||||
timeout time.Duration
|
||||
|
||||
// for http or https
|
||||
method string
|
||||
|
|
@ -134,11 +133,6 @@ func (r *Request) Body(content []byte) *Request {
|
|||
return r
|
||||
}
|
||||
|
||||
func (r *Request) Resolver(resolver *net.Resolver) *Request {
|
||||
r.resolver = resolver
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Request) Do() (*Response, error) {
|
||||
var (
|
||||
conn net.Conn
|
||||
|
|
@ -169,7 +163,7 @@ func (r *Request) Do() (*Response, error) {
|
|||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dialer := &net.Dialer{Resolver: r.resolver}
|
||||
dialer := &net.Dialer{}
|
||||
switch r.protocol {
|
||||
case "tcp":
|
||||
conn, err = dialer.Dial("tcp", addr)
|
||||
|
|
@ -225,7 +219,6 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma
|
|||
Timeout: time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
Resolver: r.resolver,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue