mirror of
https://github.com/fatedier/frp.git
synced 2026-05-15 08:05:49 -06:00
refactor: clean up code
This commit is contained in:
parent
8666e3643f
commit
3845587393
49 changed files with 2082 additions and 931 deletions
2
go.mod
2
go.mod
|
|
@ -36,6 +36,7 @@ require (
|
||||||
gopkg.in/ini.v1 v1.67.0
|
gopkg.in/ini.v1 v1.67.0
|
||||||
k8s.io/apimachinery v0.28.8
|
k8s.io/apimachinery v0.28.8
|
||||||
k8s.io/client-go v0.28.8
|
k8s.io/client-go v0.28.8
|
||||||
|
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
@ -75,7 +76,6 @@ require (
|
||||||
google.golang.org/protobuf v1.36.5 // indirect
|
google.golang.org/protobuf v1.36.5 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
|
|
||||||
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
|
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
|
||||||
sigs.k8s.io/yaml v1.3.0 // indirect
|
sigs.k8s.io/yaml v1.3.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,7 @@
|
||||||
|
|
||||||
package source
|
package source
|
||||||
|
|
||||||
import (
|
import v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConfigSource implements Source for in-memory configuration.
|
// ConfigSource implements Source for in-memory configuration.
|
||||||
// All operations are thread-safe.
|
// All operations are thread-safe.
|
||||||
|
|
@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi
|
||||||
|
|
||||||
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
|
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
|
||||||
for _, p := range proxies {
|
for _, p := range proxies {
|
||||||
if p == nil {
|
name, err := validateProxyName(p)
|
||||||
return fmt.Errorf("proxy cannot be nil")
|
if err != nil {
|
||||||
}
|
return err
|
||||||
name := p.GetBaseConfig().Name
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Errorf("proxy name cannot be empty")
|
|
||||||
}
|
}
|
||||||
nextProxies[name] = p
|
nextProxies[name] = p
|
||||||
}
|
}
|
||||||
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
|
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
|
||||||
for _, v := range visitors {
|
for _, v := range visitors {
|
||||||
if v == nil {
|
name, err := validateVisitorName(v)
|
||||||
return fmt.Errorf("visitor cannot be nil")
|
if err != nil {
|
||||||
}
|
return err
|
||||||
name := v.GetBaseConfig().Name
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Errorf("visitor name cannot be empty")
|
|
||||||
}
|
}
|
||||||
nextVisitors[name] = v
|
nextVisitors[name] = v
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,11 @@ var (
|
||||||
ErrNotFound = errors.New("not found")
|
ErrNotFound = errors.New("not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
storeKindProxy = "proxy"
|
||||||
|
storeKindVisitor = "visitor"
|
||||||
|
)
|
||||||
|
|
||||||
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
||||||
if cfg.Path == "" {
|
if cfg.Path == "" {
|
||||||
return nil, fmt.Errorf("path is required")
|
return nil, fmt.Errorf("path is required")
|
||||||
|
|
@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error {
|
||||||
if proxy == nil {
|
if err := s.saveToFileUnlocked(); err != nil {
|
||||||
return fmt.Errorf("proxy cannot be nil")
|
rollback()
|
||||||
|
return fmt.Errorf("failed to persist: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store map selectors return the target map for generic helpers.
|
||||||
|
func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer {
|
||||||
|
return s.proxies
|
||||||
|
}
|
||||||
|
|
||||||
|
func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer {
|
||||||
|
return s.visitors
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps.
|
||||||
|
// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer.
|
||||||
|
func addStoreEntry[T any](
|
||||||
|
s *StoreSource,
|
||||||
|
entriesFn func(*StoreSource) map[string]T,
|
||||||
|
kind string,
|
||||||
|
name string,
|
||||||
|
value T,
|
||||||
|
) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
entries := entriesFn(s)
|
||||||
|
if _, exists := entries[name]; exists {
|
||||||
|
return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
name := proxy.GetBaseConfig().Name
|
entries[name] = value
|
||||||
|
return s.persistOrRollbackUnlocked(func() {
|
||||||
|
delete(entries, name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateStoreEntry[T any](
|
||||||
|
s *StoreSource,
|
||||||
|
entriesFn func(*StoreSource) map[string]T,
|
||||||
|
kind string,
|
||||||
|
name string,
|
||||||
|
value T,
|
||||||
|
) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
entries := entriesFn(s)
|
||||||
|
old, exists := entries[name]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries[name] = value
|
||||||
|
return s.persistOrRollbackUnlocked(func() {
|
||||||
|
entries[name] = old
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeStoreEntry[T any](
|
||||||
|
s *StoreSource,
|
||||||
|
entriesFn func(*StoreSource) map[string]T,
|
||||||
|
kind string,
|
||||||
|
name string,
|
||||||
|
) error {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return fmt.Errorf("proxy name cannot be empty")
|
return fmt.Errorf("%s name cannot be empty", kind)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
if _, exists := s.proxies[name]; exists {
|
entries := entriesFn(s)
|
||||||
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
|
old, exists := entries[name]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.proxies[name] = proxy
|
delete(entries, name)
|
||||||
|
return s.persistOrRollbackUnlocked(func() {
|
||||||
|
entries[name] = old
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||||
delete(s.proxies, name)
|
name, err := validateProxyName(proxy)
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
||||||
if proxy == nil {
|
name, err := validateProxyName(proxy)
|
||||||
return fmt.Errorf("proxy cannot be nil")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||||
name := proxy.GetBaseConfig().Name
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Errorf("proxy name cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
oldProxy, exists := s.proxies[name]
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.proxies[name] = proxy
|
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
|
||||||
s.proxies[name] = oldProxy
|
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) RemoveProxy(name string) error {
|
func (s *StoreSource) RemoveProxy(name string) error {
|
||||||
if name == "" {
|
return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name)
|
||||||
return fmt.Errorf("proxy name cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
oldProxy, exists := s.proxies[name]
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.proxies, name)
|
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
|
||||||
s.proxies[name] = oldProxy
|
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||||
|
|
@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
||||||
if visitor == nil {
|
name, err := validateVisitorName(visitor)
|
||||||
return fmt.Errorf("visitor cannot be nil")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||||
name := visitor.GetBaseConfig().Name
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Errorf("visitor name cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
if _, exists := s.visitors[name]; exists {
|
|
||||||
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.visitors[name] = visitor
|
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
|
||||||
delete(s.visitors, name)
|
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
||||||
if visitor == nil {
|
name, err := validateVisitorName(visitor)
|
||||||
return fmt.Errorf("visitor cannot be nil")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||||
name := visitor.GetBaseConfig().Name
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Errorf("visitor name cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
oldVisitor, exists := s.visitors[name]
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.visitors[name] = visitor
|
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
|
||||||
s.visitors[name] = oldVisitor
|
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) RemoveVisitor(name string) error {
|
func (s *StoreSource) RemoveVisitor(name string) error {
|
||||||
if name == "" {
|
return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name)
|
||||||
return fmt.Errorf("visitor name cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
oldVisitor, exists := s.visitors[name]
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.visitors, name)
|
|
||||||
|
|
||||||
if err := s.saveToFileUnlocked(); err != nil {
|
|
||||||
s.visitors[name] = oldVisitor
|
|
||||||
return fmt.Errorf("failed to persist: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {
|
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ package source
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T
|
||||||
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
storeSource := newTestStoreSource(t)
|
||||||
|
|
||||||
|
proxyCfg := mockProxy("proxy1")
|
||||||
|
visitorCfg := mockVisitor("visitor1")
|
||||||
|
|
||||||
|
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||||
|
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||||
|
require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists)
|
||||||
|
require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists)
|
||||||
|
require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty")
|
||||||
|
require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty")
|
||||||
|
|
||||||
|
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||||
|
updatedProxy.RemotePort = 19090
|
||||||
|
require.NoError(storeSource.UpdateProxy(updatedProxy))
|
||||||
|
require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||||
|
|
||||||
|
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||||
|
updatedVisitor.ServerName = "updated-server"
|
||||||
|
require.NoError(storeSource.UpdateVisitor(updatedVisitor))
|
||||||
|
require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||||
|
|
||||||
|
require.NoError(storeSource.RemoveProxy("proxy1"))
|
||||||
|
require.Nil(storeSource.GetProxy("proxy1"))
|
||||||
|
require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound)
|
||||||
|
|
||||||
|
require.NoError(storeSource.RemoveVisitor("visitor1"))
|
||||||
|
require.Nil(storeSource.GetVisitor("visitor1"))
|
||||||
|
require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound)
|
||||||
|
|
||||||
|
require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound)
|
||||||
|
require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("chmod does not make directories unwritable on Windows")
|
||||||
|
}
|
||||||
|
if os.Getuid() == 0 {
|
||||||
|
t.Skip("chmod does not block writes for uid 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "store.json")
|
||||||
|
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
|
||||||
|
require.NoError(err)
|
||||||
|
|
||||||
|
proxyCfg := mockProxy("proxy1")
|
||||||
|
visitorCfg := mockVisitor("visitor1")
|
||||||
|
originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort
|
||||||
|
originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName
|
||||||
|
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||||
|
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||||
|
|
||||||
|
require.NoError(os.Chmod(dir, 0o500))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Chmod(dir, 0o700)
|
||||||
|
})
|
||||||
|
|
||||||
|
requirePersistError := func(err error) {
|
||||||
|
t.Helper()
|
||||||
|
require.Error(err)
|
||||||
|
require.ErrorContains(err, "failed to persist")
|
||||||
|
require.NotErrorIs(err, ErrAlreadyExists)
|
||||||
|
require.NotErrorIs(err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
requirePersistError(storeSource.AddProxy(mockProxy("proxy2")))
|
||||||
|
require.Nil(storeSource.GetProxy("proxy2"))
|
||||||
|
|
||||||
|
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||||
|
updatedProxy.RemotePort = 19090
|
||||||
|
requirePersistError(storeSource.UpdateProxy(updatedProxy))
|
||||||
|
require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||||
|
|
||||||
|
requirePersistError(storeSource.RemoveProxy("proxy1"))
|
||||||
|
require.NotNil(storeSource.GetProxy("proxy1"))
|
||||||
|
|
||||||
|
requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2")))
|
||||||
|
require.Nil(storeSource.GetVisitor("visitor2"))
|
||||||
|
|
||||||
|
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||||
|
updatedVisitor.ServerName = "updated-server"
|
||||||
|
requirePersistError(storeSource.UpdateVisitor(updatedVisitor))
|
||||||
|
require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||||
|
|
||||||
|
requirePersistError(storeSource.RemoveVisitor("visitor1"))
|
||||||
|
require.NotNil(storeSource.GetVisitor("visitor1"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
|
|
|
||||||
43
pkg/config/source/validation.go
Normal file
43
pkg/config/source/validation.go
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package source
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateProxyName(proxy v1.ProxyConfigurer) (string, error) {
|
||||||
|
if proxy == nil {
|
||||||
|
return "", fmt.Errorf("proxy cannot be nil")
|
||||||
|
}
|
||||||
|
name := proxy.GetBaseConfig().Name
|
||||||
|
if name == "" {
|
||||||
|
return "", fmt.Errorf("proxy name cannot be empty")
|
||||||
|
}
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) {
|
||||||
|
if visitor == nil {
|
||||||
|
return "", fmt.Errorf("visitor cannot be nil")
|
||||||
|
}
|
||||||
|
name := visitor.GetBaseConfig().Name
|
||||||
|
if name == "" {
|
||||||
|
return "", fmt.Errorf("visitor name cannot be empty")
|
||||||
|
}
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
43
pkg/config/v1/validation/auth.go
Normal file
43
pkg/config/v1/validation/auth.go
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (v *ConfigValidator) validateAuthTokenSource(token string, tokenSource *v1.ValueSource) error {
|
||||||
|
var errs error
|
||||||
|
// Preserve the previous client/server validation order for joined errors.
|
||||||
|
if token != "" && tokenSource != nil {
|
||||||
|
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||||
|
}
|
||||||
|
if tokenSource == nil {
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenSource.Type == "exec" {
|
||||||
|
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||||
|
errs = AppendError(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := tokenSource.Validate(); err != nil {
|
||||||
|
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||||
|
}
|
||||||
|
return errs
|
||||||
|
}
|
||||||
228
pkg/config/v1/validation/auth_test.go
Normal file
228
pkg/config/v1/validation/auth_test.go
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/policy/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tokenSourceConflictErr = "cannot specify both auth.token and auth.tokenSource"
|
||||||
|
tokenSourceExecErr = "unsafe feature \"TokenSourceExec\" is not enabled. To enable it, ensure it is allowed in the configuration or command line flags"
|
||||||
|
invalidFileSourceErr = "invalid auth.tokenSource: file configuration is required when type is 'file'"
|
||||||
|
unsupportedSourceErr = "invalid auth.tokenSource: unsupported value source type: env (only 'file' and 'exec' are supported)"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateAuthTokenSource(t *testing.T) {
|
||||||
|
for _, tc := range authTokenSourceTestCases() {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||||
|
err := validator.validateAuthTokenSource(tc.token, tc.tokenSource())
|
||||||
|
requireValidationErrors(t, err, tc.wantErrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateClientAuthTokenSource(t *testing.T) {
|
||||||
|
for _, tc := range authTokenSourceTestCases() {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
auth := v1.AuthClientConfig{
|
||||||
|
Method: v1.AuthMethodToken,
|
||||||
|
Token: tc.token,
|
||||||
|
TokenSource: tc.tokenSource(),
|
||||||
|
}
|
||||||
|
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||||
|
_, err := validator.ValidateClientCommonConfig(validClientConfigWithAuth(auth))
|
||||||
|
requireValidationErrors(t, err, tc.wantErrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateServerAuthTokenSource(t *testing.T) {
|
||||||
|
for _, tc := range authTokenSourceTestCases() {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
auth := v1.AuthServerConfig{
|
||||||
|
Method: v1.AuthMethodToken,
|
||||||
|
Token: tc.token,
|
||||||
|
TokenSource: tc.tokenSource(),
|
||||||
|
}
|
||||||
|
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||||
|
_, err := validator.ValidateServerConfig(validServerConfigWithAuth(auth))
|
||||||
|
requireValidationErrors(t, err, tc.wantErrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type authTokenSourceTestCase struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
tokenSource func() *v1.ValueSource
|
||||||
|
unsafeAllowed bool
|
||||||
|
wantErrs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func authTokenSourceTestCases() []authTokenSourceTestCase {
|
||||||
|
return []authTokenSourceTestCase{
|
||||||
|
{
|
||||||
|
name: "empty token config",
|
||||||
|
tokenSource: nilTokenSource,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid file tokenSource",
|
||||||
|
tokenSource: validFileTokenSource,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "literal token without tokenSource",
|
||||||
|
token: "token",
|
||||||
|
tokenSource: nilTokenSource,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "literal token conflicts with file tokenSource",
|
||||||
|
token: "token",
|
||||||
|
tokenSource: validFileTokenSource,
|
||||||
|
wantErrs: []string{tokenSourceConflictErr},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec tokenSource requires unsafe feature",
|
||||||
|
tokenSource: validExecTokenSource,
|
||||||
|
wantErrs: []string{tokenSourceExecErr},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec tokenSource with unsafe feature allowed",
|
||||||
|
tokenSource: validExecTokenSource,
|
||||||
|
unsafeAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "literal token conflicts with exec tokenSource and unsafe feature disabled",
|
||||||
|
token: "token",
|
||||||
|
tokenSource: validExecTokenSource,
|
||||||
|
wantErrs: []string{
|
||||||
|
tokenSourceConflictErr,
|
||||||
|
tokenSourceExecErr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "literal token conflicts with exec tokenSource and unsafe feature allowed",
|
||||||
|
token: "token",
|
||||||
|
tokenSource: validExecTokenSource,
|
||||||
|
unsafeAllowed: true,
|
||||||
|
wantErrs: []string{tokenSourceConflictErr},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid file tokenSource is wrapped",
|
||||||
|
tokenSource: invalidFileTokenSource,
|
||||||
|
wantErrs: []string{invalidFileSourceErr},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported tokenSource type is wrapped",
|
||||||
|
tokenSource: unsupportedTokenSource,
|
||||||
|
wantErrs: []string{unsupportedSourceErr},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthTokenSourceValidator(unsafeAllowed bool) *ConfigValidator {
|
||||||
|
if !unsafeAllowed {
|
||||||
|
return NewConfigValidator(nil)
|
||||||
|
}
|
||||||
|
return NewConfigValidator(security.NewUnsafeFeatures([]string{security.TokenSourceExec}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireValidationErrors(t *testing.T, err error, wantErrs []string) {
|
||||||
|
t.Helper()
|
||||||
|
if len(wantErrs) == 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.Error(t, err)
|
||||||
|
// Client/server validators may wrap joined errors in another join layer; compare leaf errors.
|
||||||
|
gotErrs := unwrapValidationErrors(err)
|
||||||
|
require.Len(t, gotErrs, len(wantErrs))
|
||||||
|
for i, wantErr := range wantErrs {
|
||||||
|
require.EqualError(t, gotErrs[i], wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unwrapValidationErrors(err error) []error {
|
||||||
|
type joinedError interface {
|
||||||
|
Unwrap() []error
|
||||||
|
}
|
||||||
|
joined, ok := err.(joinedError)
|
||||||
|
if !ok {
|
||||||
|
return []error{err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, err := range joined.Unwrap() {
|
||||||
|
errs = append(errs, unwrapValidationErrors(err)...)
|
||||||
|
}
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// nilTokenSource keeps the shared table shape uniform for cases without a tokenSource.
|
||||||
|
func nilTokenSource() *v1.ValueSource {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validFileTokenSource() *v1.ValueSource {
|
||||||
|
return &v1.ValueSource{
|
||||||
|
Type: "file",
|
||||||
|
File: &v1.FileSource{Path: "token.txt"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validExecTokenSource() *v1.ValueSource {
|
||||||
|
return &v1.ValueSource{
|
||||||
|
Type: "exec",
|
||||||
|
Exec: &v1.ExecSource{Command: "print-token"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func invalidFileTokenSource() *v1.ValueSource {
|
||||||
|
return &v1.ValueSource{
|
||||||
|
Type: "file",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsupportedTokenSource() *v1.ValueSource {
|
||||||
|
return &v1.ValueSource{Type: "env"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validClientConfigWithAuth(auth v1.AuthClientConfig) *v1.ClientCommonConfig {
|
||||||
|
return &v1.ClientCommonConfig{
|
||||||
|
Auth: auth,
|
||||||
|
Log: v1.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
},
|
||||||
|
Transport: v1.ClientTransportConfig{
|
||||||
|
Protocol: "tcp",
|
||||||
|
WireProtocol: "v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validServerConfigWithAuth(auth v1.AuthServerConfig) *v1.ServerConfig {
|
||||||
|
return &v1.ServerConfig{
|
||||||
|
Auth: auth,
|
||||||
|
Log: v1.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -68,22 +68,7 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate token/tokenSource mutual exclusivity
|
errs = AppendError(errs, v.validateAuthTokenSource(c.Token, c.TokenSource))
|
||||||
if c.Token != "" && c.TokenSource != nil {
|
|
||||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate tokenSource if specified
|
|
||||||
if c.TokenSource != nil {
|
|
||||||
if c.TokenSource.Type == "exec" {
|
|
||||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
|
||||||
errs = AppendError(errs, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := c.TokenSource.Validate(); err != nil {
|
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||||
errs = AppendError(errs, err)
|
errs = AppendError(errs, err)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/policy/security"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||||
|
|
@ -36,22 +35,7 @@ func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, err
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate token/tokenSource mutual exclusivity
|
errs = AppendError(errs, v.validateAuthTokenSource(c.Auth.Token, c.Auth.TokenSource))
|
||||||
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
|
|
||||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate tokenSource if specified
|
|
||||||
if c.Auth.TokenSource != nil {
|
|
||||||
if c.Auth.TokenSource.Type == "exec" {
|
|
||||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
|
||||||
errs = AppendError(errs, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
|
||||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateLogConfig(&c.Log); err != nil {
|
if err := validateLogConfig(&c.Log); err != nil {
|
||||||
errs = AppendError(errs, err)
|
errs = AppendError(errs, err)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
"github.com/fatedier/frp/pkg/util/metric"
|
"github.com/fatedier/frp/pkg/util/metric"
|
||||||
server "github.com/fatedier/frp/server/metrics"
|
server "github.com/fatedier/frp/server/metrics"
|
||||||
|
|
@ -37,12 +39,21 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverMetrics struct {
|
type serverMetrics struct {
|
||||||
info *ServerStatistics
|
info *ServerStatistics
|
||||||
mu sync.Mutex
|
clock clock.WithTicker
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerMetrics() *serverMetrics {
|
func newServerMetrics() *serverMetrics {
|
||||||
|
return newServerMetricsWithClock(clock.RealClock{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerMetricsWithClock(clk clock.WithTicker) *serverMetrics {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
return &serverMetrics{
|
return &serverMetrics{
|
||||||
|
clock: clk,
|
||||||
info: &ServerStatistics{
|
info: &ServerStatistics{
|
||||||
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
|
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
|
||||||
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
|
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
|
||||||
|
|
@ -57,14 +68,23 @@ func newServerMetrics() *serverMetrics {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverMetrics) run() {
|
func (m *serverMetrics) run() {
|
||||||
go func() {
|
go m.runUntil(nil)
|
||||||
for {
|
}
|
||||||
time.Sleep(12 * time.Hour)
|
|
||||||
start := time.Now()
|
func (m *serverMetrics) runUntil(stopCh <-chan struct{}) {
|
||||||
|
ticker := m.clock.NewTicker(12 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C():
|
||||||
|
start := m.clock.Now()
|
||||||
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
|
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
|
||||||
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start))
|
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, m.clock.Since(start))
|
||||||
|
case <-stopCh:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) {
|
func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) {
|
||||||
|
|
@ -77,7 +97,7 @@ func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration
|
||||||
for name, data := range m.info.ProxyStatistics {
|
for name, data := range m.info.ProxyStatistics {
|
||||||
if !data.LastCloseTime.IsZero() &&
|
if !data.LastCloseTime.IsZero() &&
|
||||||
data.LastStartTime.Before(data.LastCloseTime) &&
|
data.LastStartTime.Before(data.LastCloseTime) &&
|
||||||
time.Since(data.LastCloseTime) > continuousOfflineDuration {
|
m.clock.Since(data.LastCloseTime) > continuousOfflineDuration {
|
||||||
delete(m.info.ProxyStatistics, name)
|
delete(m.info.ProxyStatistics, name)
|
||||||
count++
|
count++
|
||||||
log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
||||||
|
|
@ -121,7 +141,7 @@ func (m *serverMetrics) NewProxy(name string, proxyType string, user string, cli
|
||||||
}
|
}
|
||||||
proxyStats.User = user
|
proxyStats.User = user
|
||||||
proxyStats.ClientID = clientID
|
proxyStats.ClientID = clientID
|
||||||
proxyStats.LastStartTime = time.Now()
|
proxyStats.LastStartTime = m.clock.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||||
|
|
@ -131,7 +151,7 @@ func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||||
counter.Dec(1)
|
counter.Dec(1)
|
||||||
}
|
}
|
||||||
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
|
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
|
||||||
proxyStats.LastCloseTime = time.Now()
|
proxyStats.LastCloseTime = m.clock.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
83
pkg/metrics/mem/server_test.go
Normal file
83
pkg/metrics/mem/server_test.go
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
package mem
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerMetricsUsesClockForProxyTimestamps(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
metrics := newServerMetricsWithClock(clk)
|
||||||
|
|
||||||
|
metrics.NewProxy("proxy", "tcp", "user", "client-id")
|
||||||
|
require.Equal(start, metrics.info.ProxyStatistics["proxy"].LastStartTime)
|
||||||
|
|
||||||
|
closedAt := start.Add(time.Minute)
|
||||||
|
clk.SetTime(closedAt)
|
||||||
|
metrics.CloseProxy("proxy", "tcp")
|
||||||
|
require.Equal(closedAt, metrics.info.ProxyStatistics["proxy"].LastCloseTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerMetricsClearUselessInfoUsesClock(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start.Add(25 * time.Hour))
|
||||||
|
metrics := newServerMetricsWithClock(clk)
|
||||||
|
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||||
|
Name: "proxy",
|
||||||
|
LastStartTime: start.Add(-time.Hour),
|
||||||
|
LastCloseTime: start,
|
||||||
|
}
|
||||||
|
|
||||||
|
count, total := metrics.clearUselessInfo(24 * time.Hour)
|
||||||
|
|
||||||
|
require.Equal(1, count)
|
||||||
|
require.Equal(1, total)
|
||||||
|
require.Empty(metrics.info.ProxyStatistics)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerMetricsRunUsesClockTicker(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
metrics := newServerMetricsWithClock(clk)
|
||||||
|
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||||
|
Name: "proxy",
|
||||||
|
LastStartTime: start.Add(-time.Hour),
|
||||||
|
LastCloseTime: start,
|
||||||
|
}
|
||||||
|
|
||||||
|
stopCh := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
metrics.runUntil(stopCh)
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(stopCh)
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||||
|
clk.Step(8 * 24 * time.Hour)
|
||||||
|
|
||||||
|
require.Eventually(func() bool {
|
||||||
|
return !metrics.hasProxyStatistics("proxy")
|
||||||
|
}, time.Second, time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *serverMetrics) hasProxyStatistics(name string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
_, ok := m.info.ProxyStatistics[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
@ -109,10 +109,9 @@ func AsyncHandler(f func(Message)) func(Message) {
|
||||||
type Dispatcher struct {
|
type Dispatcher struct {
|
||||||
rw ReadWriter
|
rw ReadWriter
|
||||||
|
|
||||||
sendCh chan Message
|
sendCh chan Message
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
msgHandlers map[reflect.Type]func(Message)
|
msgHandlers map[reflect.Type]func(Message)
|
||||||
defaultHandler func(Message)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
||||||
|
|
@ -151,8 +150,6 @@ func (d *Dispatcher) readLoop() {
|
||||||
|
|
||||||
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
|
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
|
||||||
handler(m)
|
handler(m)
|
||||||
} else if d.defaultHandler != nil {
|
|
||||||
d.defaultHandler(m)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -170,10 +167,6 @@ func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
|
||||||
d.msgHandlers[reflect.TypeOf(msg)] = handler
|
d.msgHandlers[reflect.TypeOf(msg)] = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
|
|
||||||
d.defaultHandler = handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dispatcher) Done() chan struct{} {
|
func (d *Dispatcher) Done() chan struct{} {
|
||||||
return d.doneCh
|
return d.doneCh
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
"k8s.io/utils/clock"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -144,19 +145,19 @@ func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, Recomman
|
||||||
return behaviors[index].A, behaviors[index].B
|
return behaviors[index].A, behaviors[index].B
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
func getBehaviorScoresByMode(mode int, defaultScore int) []*behaviorScore {
|
||||||
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
|
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*behaviorScore {
|
||||||
behaviors := getBehaviorByMode(mode)
|
behaviors := getBehaviorByMode(mode)
|
||||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
scores := make([]*behaviorScore, 0, len(behaviors))
|
||||||
for i := range behaviors {
|
for i := range behaviors {
|
||||||
score := receiverScore
|
score := receiverScore
|
||||||
if behaviors[i].A.Role == DetectRoleSender {
|
if behaviors[i].A.Role == DetectRoleSender {
|
||||||
score = senderScore
|
score = senderScore
|
||||||
}
|
}
|
||||||
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
|
scores = append(scores, &behaviorScore{Mode: mode, Index: i, Score: score})
|
||||||
}
|
}
|
||||||
return scores
|
return scores
|
||||||
}
|
}
|
||||||
|
|
@ -170,14 +171,18 @@ type RecommandBehavior struct {
|
||||||
ListenRandomPorts int
|
ListenRandomPorts int
|
||||||
}
|
}
|
||||||
|
|
||||||
type MakeHoleRecords struct {
|
type makeHoleRecords struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
scores []*BehaviorScore
|
scores []*behaviorScore
|
||||||
LastUpdateTime time.Time
|
clock clock.PassiveClock
|
||||||
|
lastUpdateTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
func newMakeHoleRecordsWithClock(c, v *NatFeature, clk clock.PassiveClock) *makeHoleRecords {
|
||||||
scores := []*BehaviorScore{}
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
|
scores := []*behaviorScore{}
|
||||||
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
|
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
|
||||||
appendMode0 := func() {
|
appendMode0 := func() {
|
||||||
switch {
|
switch {
|
||||||
|
|
@ -212,13 +217,17 @@ func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
||||||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
|
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
|
||||||
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
|
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
|
||||||
}
|
}
|
||||||
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
|
return &makeHoleRecords{
|
||||||
|
scores: scores,
|
||||||
|
clock: clk,
|
||||||
|
lastUpdateTime: clk.Now(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
func (mhr *makeHoleRecords) reportSuccess(mode int, index int) {
|
||||||
mhr.mu.Lock()
|
mhr.mu.Lock()
|
||||||
defer mhr.mu.Unlock()
|
defer mhr.mu.Unlock()
|
||||||
mhr.LastUpdateTime = time.Now()
|
mhr.lastUpdateTime = mhr.clock.Now()
|
||||||
for i := range mhr.scores {
|
for i := range mhr.scores {
|
||||||
score := mhr.scores[i]
|
score := mhr.scores[i]
|
||||||
if score.Mode != mode || score.Index != index {
|
if score.Mode != mode || score.Index != index {
|
||||||
|
|
@ -231,22 +240,22 @@ func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
|
func (mhr *makeHoleRecords) recommand() (mode, index int) {
|
||||||
mhr.mu.Lock()
|
mhr.mu.Lock()
|
||||||
defer mhr.mu.Unlock()
|
defer mhr.mu.Unlock()
|
||||||
|
|
||||||
if len(mhr.scores) == 0 {
|
if len(mhr.scores) == 0 {
|
||||||
return 0, 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int {
|
maxScore := slices.MaxFunc(mhr.scores, func(a, b *behaviorScore) int {
|
||||||
return cmp.Compare(a.Score, b.Score)
|
return cmp.Compare(a.Score, b.Score)
|
||||||
})
|
})
|
||||||
maxScore.Score--
|
maxScore.Score--
|
||||||
mhr.LastUpdateTime = time.Now()
|
mhr.lastUpdateTime = mhr.clock.Now()
|
||||||
return maxScore.Mode, maxScore.Index
|
return maxScore.Mode, maxScore.Index
|
||||||
}
|
}
|
||||||
|
|
||||||
type BehaviorScore struct {
|
type behaviorScore struct {
|
||||||
Mode int
|
Mode int
|
||||||
Index int
|
Index int
|
||||||
// between -10 and 10
|
// between -10 and 10
|
||||||
|
|
@ -255,16 +264,25 @@ type BehaviorScore struct {
|
||||||
|
|
||||||
type Analyzer struct {
|
type Analyzer struct {
|
||||||
// key is client ip + visitor ip
|
// key is client ip + visitor ip
|
||||||
records map[string]*MakeHoleRecords
|
records map[string]*makeHoleRecords
|
||||||
dataReserveDuration time.Duration
|
dataReserveDuration time.Duration
|
||||||
|
clock clock.PassiveClock
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
|
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
|
||||||
|
return newAnalyzerWithClock(dataReserveDuration, clock.RealClock{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAnalyzerWithClock(dataReserveDuration time.Duration, clk clock.PassiveClock) *Analyzer {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
return &Analyzer{
|
return &Analyzer{
|
||||||
records: make(map[string]*MakeHoleRecords),
|
records: make(map[string]*makeHoleRecords),
|
||||||
dataReserveDuration: dataReserveDuration,
|
dataReserveDuration: dataReserveDuration,
|
||||||
|
clock: clk,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -272,12 +290,12 @@ func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, in
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
records, ok := a.records[key]
|
records, ok := a.records[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
records = NewMakeHoleRecords(c, v)
|
records = newMakeHoleRecordsWithClock(c, v, a.clock)
|
||||||
a.records[key] = records
|
a.records[key] = records
|
||||||
}
|
}
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
|
|
||||||
mode, index = records.Recommand()
|
mode, index = records.recommand()
|
||||||
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
|
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
|
||||||
|
|
||||||
switch mode {
|
switch mode {
|
||||||
|
|
@ -307,11 +325,11 @@ func (a *Analyzer) ReportSuccess(key string, mode, index int) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
records.ReportSuccess(mode, index)
|
records.reportSuccess(mode, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Analyzer) Clean() (int, int) {
|
func (a *Analyzer) Clean() (int, int) {
|
||||||
now := time.Now()
|
now := a.clock.Now()
|
||||||
total := 0
|
total := 0
|
||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
|
|
@ -321,7 +339,7 @@ func (a *Analyzer) Clean() (int, int) {
|
||||||
total = len(a.records)
|
total = len(a.records)
|
||||||
// clean up records that have not been used for a period of time.
|
// clean up records that have not been used for a period of time.
|
||||||
for key, records := range a.records {
|
for key, records := range a.records {
|
||||||
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
|
if now.Sub(records.lastUpdateTime) > a.dataReserveDuration {
|
||||||
delete(a.records, key)
|
delete(a.records, key)
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
|
|
|
||||||
33
pkg/nathole/analysis_test.go
Normal file
33
pkg/nathole/analysis_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
package nathole
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAnalyzerUsesClockForRecordTimestamps(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
analyzer := newAnalyzerWithClock(time.Hour, clk)
|
||||||
|
clientFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||||
|
visitorFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||||
|
|
||||||
|
mode, index, _, _ := analyzer.GetRecommandBehaviors("key", clientFeature, visitorFeature)
|
||||||
|
require.Equal(start, analyzer.records["key"].lastUpdateTime)
|
||||||
|
|
||||||
|
updatedAt := start.Add(time.Minute)
|
||||||
|
clk.SetTime(updatedAt)
|
||||||
|
analyzer.ReportSuccess("key", mode, index)
|
||||||
|
require.Equal(updatedAt, analyzer.records["key"].lastUpdateTime)
|
||||||
|
|
||||||
|
clk.SetTime(start.Add(2 * time.Hour))
|
||||||
|
count, total := analyzer.Clean()
|
||||||
|
require.Equal(1, count)
|
||||||
|
require.Equal(1, total)
|
||||||
|
require.Empty(analyzer.records)
|
||||||
|
}
|
||||||
|
|
@ -326,40 +326,16 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := vm.Protocol
|
protocol := vm.Protocol
|
||||||
vResp := &msg.NatHoleResp{
|
vResp := newNatHoleResponse(
|
||||||
TransactionID: vm.TransactionID,
|
vm.TransactionID, session.sid, protocol, mode,
|
||||||
Sid: session.sid,
|
cm.MappedAddrs, cm.AssistedAddrs, vBehavior,
|
||||||
Protocol: protocol,
|
timeoutMs-vBehavior.SendDelayMs, cNatFeature.PortsDifference,
|
||||||
CandidateAddrs: slices.Compact(cm.MappedAddrs),
|
)
|
||||||
AssistedAddrs: slices.Compact(cm.AssistedAddrs),
|
cResp := newNatHoleResponse(
|
||||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
cm.TransactionID, session.sid, protocol, mode,
|
||||||
Mode: mode,
|
vm.MappedAddrs, vm.AssistedAddrs, cBehavior,
|
||||||
Role: vBehavior.Role,
|
timeoutMs-cBehavior.SendDelayMs, vNatFeature.PortsDifference,
|
||||||
TTL: vBehavior.TTL,
|
)
|
||||||
SendDelayMs: vBehavior.SendDelayMs,
|
|
||||||
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
|
|
||||||
SendRandomPorts: vBehavior.PortsRandomNumber,
|
|
||||||
ListenRandomPorts: vBehavior.ListenRandomPorts,
|
|
||||||
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cResp := &msg.NatHoleResp{
|
|
||||||
TransactionID: cm.TransactionID,
|
|
||||||
Sid: session.sid,
|
|
||||||
Protocol: protocol,
|
|
||||||
CandidateAddrs: slices.Compact(vm.MappedAddrs),
|
|
||||||
AssistedAddrs: slices.Compact(vm.AssistedAddrs),
|
|
||||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
|
||||||
Mode: mode,
|
|
||||||
Role: cBehavior.Role,
|
|
||||||
TTL: cBehavior.TTL,
|
|
||||||
SendDelayMs: cBehavior.SendDelayMs,
|
|
||||||
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
|
|
||||||
SendRandomPorts: cBehavior.PortsRandomNumber,
|
|
||||||
ListenRandomPorts: cBehavior.ListenRandomPorts,
|
|
||||||
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
|
log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
|
||||||
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
|
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
|
||||||
|
|
@ -368,6 +344,38 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
||||||
return vResp, cResp, nil
|
return vResp, cResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newNatHoleResponse(
|
||||||
|
transactionID string,
|
||||||
|
sid string,
|
||||||
|
protocol string,
|
||||||
|
mode int,
|
||||||
|
candidateAddrs []string,
|
||||||
|
assistedAddrs []string,
|
||||||
|
behavior RecommandBehavior,
|
||||||
|
readTimeoutMs int,
|
||||||
|
portsDifference int,
|
||||||
|
) *msg.NatHoleResp {
|
||||||
|
compactCandidateAddrs := slices.Compact(candidateAddrs)
|
||||||
|
compactAssistedAddrs := slices.Compact(assistedAddrs)
|
||||||
|
return &msg.NatHoleResp{
|
||||||
|
TransactionID: transactionID,
|
||||||
|
Sid: sid,
|
||||||
|
Protocol: protocol,
|
||||||
|
CandidateAddrs: compactCandidateAddrs,
|
||||||
|
AssistedAddrs: compactAssistedAddrs,
|
||||||
|
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||||
|
Mode: mode,
|
||||||
|
Role: behavior.Role,
|
||||||
|
TTL: behavior.TTL,
|
||||||
|
SendDelayMs: behavior.SendDelayMs,
|
||||||
|
ReadTimeoutMs: readTimeoutMs,
|
||||||
|
SendRandomPorts: behavior.PortsRandomNumber,
|
||||||
|
ListenRandomPorts: behavior.ListenRandomPorts,
|
||||||
|
CandidatePorts: getRangePorts(candidateAddrs, portsDifference, behavior.PortsRangeNumber),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
||||||
if maxNumber <= 0 {
|
if maxNumber <= 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -17,17 +17,9 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
stdlog "log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fatedier/golib/pool"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -37,57 +29,28 @@ func init() {
|
||||||
type HTTP2HTTPPlugin struct {
|
type HTTP2HTTPPlugin struct {
|
||||||
opts *v1.HTTP2HTTPPluginOptions
|
opts *v1.HTTP2HTTPPluginOptions
|
||||||
|
|
||||||
l *Listener
|
*httpBridgePlugin
|
||||||
s *http.Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||||
opts := options.(*v1.HTTP2HTTPPluginOptions)
|
opts := options.(*v1.HTTP2HTTPPluginOptions)
|
||||||
|
|
||||||
listener := NewProxyListener()
|
|
||||||
|
|
||||||
p := &HTTP2HTTPPlugin{
|
p := &HTTP2HTTPPlugin{
|
||||||
opts: opts,
|
opts: opts,
|
||||||
l: listener,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &httputil.ReverseProxy{
|
rp := newHTTPBridgeReverseProxy(
|
||||||
Rewrite: func(r *httputil.ProxyRequest) {
|
func(r *httputil.ProxyRequest) {
|
||||||
req := r.Out
|
req := r.Out
|
||||||
req.URL.Scheme = "http"
|
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||||
req.URL.Host = p.opts.LocalAddr
|
|
||||||
if p.opts.HostHeaderRewrite != "" {
|
|
||||||
req.Host = p.opts.HostHeaderRewrite
|
|
||||||
}
|
|
||||||
for k, v := range p.opts.RequestHeaders.Set {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
BufferPool: pool.NewBuffer(32 * 1024),
|
nil,
|
||||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
)
|
||||||
}
|
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||||
|
|
||||||
p.s = &http.Server{
|
|
||||||
Handler: rp,
|
|
||||||
ReadHeaderTimeout: 60 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = p.s.Serve(listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTP2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
|
||||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
|
||||||
_ = p.l.PutConn(wrapConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTP2HTTPPlugin) Name() string {
|
func (p *HTTP2HTTPPlugin) Name() string {
|
||||||
return v1.PluginHTTP2HTTP
|
return v1.PluginHTTP2HTTP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTP2HTTPPlugin) Close() error {
|
|
||||||
return p.s.Close()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -17,18 +17,11 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
stdlog "log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fatedier/golib/pool"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -38,65 +31,35 @@ func init() {
|
||||||
type HTTP2HTTPSPlugin struct {
|
type HTTP2HTTPSPlugin struct {
|
||||||
opts *v1.HTTP2HTTPSPluginOptions
|
opts *v1.HTTP2HTTPSPluginOptions
|
||||||
|
|
||||||
l *Listener
|
*httpBridgePlugin
|
||||||
s *http.Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||||
opts := options.(*v1.HTTP2HTTPSPluginOptions)
|
opts := options.(*v1.HTTP2HTTPSPluginOptions)
|
||||||
|
|
||||||
listener := NewProxyListener()
|
|
||||||
|
|
||||||
p := &HTTP2HTTPSPlugin{
|
p := &HTTP2HTTPSPlugin{
|
||||||
opts: opts,
|
opts: opts,
|
||||||
l: listener,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &httputil.ReverseProxy{
|
rp := newHTTPBridgeReverseProxy(
|
||||||
Rewrite: func(r *httputil.ProxyRequest) {
|
func(r *httputil.ProxyRequest) {
|
||||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||||
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
|
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
|
||||||
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
|
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
|
||||||
req := r.Out
|
req := r.Out
|
||||||
req.URL.Scheme = "https"
|
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||||
req.URL.Host = p.opts.LocalAddr
|
|
||||||
if p.opts.HostHeaderRewrite != "" {
|
|
||||||
req.Host = p.opts.HostHeaderRewrite
|
|
||||||
}
|
|
||||||
for k, v := range p.opts.RequestHeaders.Set {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
Transport: tr,
|
tr,
|
||||||
BufferPool: pool.NewBuffer(32 * 1024),
|
)
|
||||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||||
}
|
|
||||||
|
|
||||||
p.s = &http.Server{
|
|
||||||
Handler: rp,
|
|
||||||
ReadHeaderTimeout: 60 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = p.s.Serve(listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTP2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
|
||||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
|
||||||
_ = p.l.PutConn(wrapConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTP2HTTPSPlugin) Name() string {
|
func (p *HTTP2HTTPSPlugin) Name() string {
|
||||||
return v1.PluginHTTP2HTTPS
|
return v1.PluginHTTP2HTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTP2HTTPSPlugin) Close() error {
|
|
||||||
return p.s.Close()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
126
pkg/plugin/client/http_common.go
Normal file
126
pkg/plugin/client/http_common.go
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
stdlog "log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fatedier/golib/pool"
|
||||||
|
|
||||||
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
|
"github.com/fatedier/frp/pkg/plugin/client/internal/httpsserver"
|
||||||
|
"github.com/fatedier/frp/pkg/util/log"
|
||||||
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const httpBridgeReadHeaderTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
func rewriteHTTPPluginRequest(
|
||||||
|
req *http.Request,
|
||||||
|
scheme string,
|
||||||
|
localAddr string,
|
||||||
|
hostHeaderRewrite string,
|
||||||
|
requestHeaders v1.HeaderOperations,
|
||||||
|
) {
|
||||||
|
req.URL.Scheme = scheme
|
||||||
|
req.URL.Host = localAddr
|
||||||
|
if hostHeaderRewrite != "" {
|
||||||
|
req.Host = hostHeaderRewrite
|
||||||
|
}
|
||||||
|
for k, v := range requestHeaders.Set {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpBridgePlugin struct {
|
||||||
|
l *Listener
|
||||||
|
s *http.Server
|
||||||
|
|
||||||
|
useSourceRemoteAddr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPBridgePluginServer(handler http.Handler, useSourceRemoteAddr bool) *httpBridgePlugin {
|
||||||
|
listener := NewProxyListener()
|
||||||
|
p := &httpBridgePlugin{
|
||||||
|
l: listener,
|
||||||
|
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||||
|
}
|
||||||
|
p.s = &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadHeaderTimeout: httpBridgeReadHeaderTimeout,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
_ = p.s.Serve(listener)
|
||||||
|
}()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPSBridgePluginServer(
|
||||||
|
handler http.Handler,
|
||||||
|
crtPath string,
|
||||||
|
keyPath string,
|
||||||
|
enableHTTP2 *bool,
|
||||||
|
useSourceRemoteAddr bool,
|
||||||
|
) (*httpBridgePlugin, error) {
|
||||||
|
listener := NewProxyListener()
|
||||||
|
server, err := httpsserver.New(handler, crtPath, keyPath, enableHTTP2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p := &httpBridgePlugin{
|
||||||
|
l: listener,
|
||||||
|
s: server,
|
||||||
|
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
_ = p.s.ServeTLS(listener, "", "")
|
||||||
|
}()
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPBridgeReverseProxy(
|
||||||
|
rewrite func(*httputil.ProxyRequest),
|
||||||
|
transport http.RoundTripper,
|
||||||
|
) *httputil.ReverseProxy {
|
||||||
|
rp := &httputil.ReverseProxy{
|
||||||
|
Rewrite: rewrite,
|
||||||
|
BufferPool: pool.NewBuffer(32 * 1024),
|
||||||
|
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||||
|
}
|
||||||
|
if transport != nil {
|
||||||
|
rp.Transport = transport
|
||||||
|
}
|
||||||
|
return rp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *httpBridgePlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||||
|
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||||
|
if p.useSourceRemoteAddr && connInfo.SrcAddr != nil {
|
||||||
|
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||||
|
}
|
||||||
|
_ = p.l.PutConn(wrapConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *httpBridgePlugin) Close() error {
|
||||||
|
err := p.s.Close()
|
||||||
|
_ = p.l.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
@ -17,22 +17,9 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
stdlog "log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fatedier/golib/pool"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
|
||||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -42,80 +29,35 @@ func init() {
|
||||||
type HTTPS2HTTPPlugin struct {
|
type HTTPS2HTTPPlugin struct {
|
||||||
opts *v1.HTTPS2HTTPPluginOptions
|
opts *v1.HTTPS2HTTPPluginOptions
|
||||||
|
|
||||||
l *Listener
|
*httpBridgePlugin
|
||||||
s *http.Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||||
opts := options.(*v1.HTTPS2HTTPPluginOptions)
|
opts := options.(*v1.HTTPS2HTTPPluginOptions)
|
||||||
listener := NewProxyListener()
|
|
||||||
|
|
||||||
p := &HTTPS2HTTPPlugin{
|
p := &HTTPS2HTTPPlugin{
|
||||||
opts: opts,
|
opts: opts,
|
||||||
l: listener,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &httputil.ReverseProxy{
|
rp := newHTTPBridgeReverseProxy(
|
||||||
Rewrite: func(r *httputil.ProxyRequest) {
|
func(r *httputil.ProxyRequest) {
|
||||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||||
r.SetXForwarded()
|
r.SetXForwarded()
|
||||||
req := r.Out
|
req := r.Out
|
||||||
req.URL.Scheme = "http"
|
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||||
req.URL.Host = p.opts.LocalAddr
|
|
||||||
if p.opts.HostHeaderRewrite != "" {
|
|
||||||
req.Host = p.opts.HostHeaderRewrite
|
|
||||||
}
|
|
||||||
for k, v := range p.opts.RequestHeaders.Set {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
BufferPool: pool.NewBuffer(32 * 1024),
|
nil,
|
||||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
)
|
||||||
}
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.TLS != nil {
|
|
||||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
|
||||||
host, _ := httppkg.CanonicalHost(r.Host)
|
|
||||||
if tlsServerName != "" && tlsServerName != host {
|
|
||||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rp.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
p.httpBridgePlugin = server
|
||||||
|
|
||||||
p.s = &http.Server{
|
|
||||||
Handler: handler,
|
|
||||||
ReadHeaderTimeout: 60 * time.Second,
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
}
|
|
||||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
|
||||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = p.s.ServeTLS(listener, "", "")
|
|
||||||
}()
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTPS2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
|
||||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
|
||||||
if connInfo.SrcAddr != nil {
|
|
||||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
|
||||||
}
|
|
||||||
_ = p.l.PutConn(wrapConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPS2HTTPPlugin) Name() string {
|
func (p *HTTPS2HTTPPlugin) Name() string {
|
||||||
return v1.PluginHTTPS2HTTP
|
return v1.PluginHTTPS2HTTP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTPS2HTTPPlugin) Close() error {
|
|
||||||
return p.s.Close()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -17,22 +17,11 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
|
||||||
stdlog "log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fatedier/golib/pool"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
|
|
||||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||||
"github.com/fatedier/frp/pkg/transport"
|
|
||||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
|
||||||
"github.com/fatedier/frp/pkg/util/log"
|
|
||||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -42,86 +31,39 @@ func init() {
|
||||||
type HTTPS2HTTPSPlugin struct {
|
type HTTPS2HTTPSPlugin struct {
|
||||||
opts *v1.HTTPS2HTTPSPluginOptions
|
opts *v1.HTTPS2HTTPSPluginOptions
|
||||||
|
|
||||||
l *Listener
|
*httpBridgePlugin
|
||||||
s *http.Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||||
opts := options.(*v1.HTTPS2HTTPSPluginOptions)
|
opts := options.(*v1.HTTPS2HTTPSPluginOptions)
|
||||||
|
|
||||||
listener := NewProxyListener()
|
|
||||||
|
|
||||||
p := &HTTPS2HTTPSPlugin{
|
p := &HTTPS2HTTPSPlugin{
|
||||||
opts: opts,
|
opts: opts,
|
||||||
l: listener,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &httputil.ReverseProxy{
|
rp := newHTTPBridgeReverseProxy(
|
||||||
Rewrite: func(r *httputil.ProxyRequest) {
|
func(r *httputil.ProxyRequest) {
|
||||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||||
r.SetXForwarded()
|
r.SetXForwarded()
|
||||||
req := r.Out
|
req := r.Out
|
||||||
req.URL.Scheme = "https"
|
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||||
req.URL.Host = p.opts.LocalAddr
|
|
||||||
if p.opts.HostHeaderRewrite != "" {
|
|
||||||
req.Host = p.opts.HostHeaderRewrite
|
|
||||||
}
|
|
||||||
for k, v := range p.opts.RequestHeaders.Set {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
Transport: tr,
|
tr,
|
||||||
BufferPool: pool.NewBuffer(32 * 1024),
|
)
|
||||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
|
||||||
}
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.TLS != nil {
|
|
||||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
|
||||||
host, _ := httppkg.CanonicalHost(r.Host)
|
|
||||||
if tlsServerName != "" && tlsServerName != host {
|
|
||||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rp.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
p.httpBridgePlugin = server
|
||||||
|
|
||||||
p.s = &http.Server{
|
|
||||||
Handler: handler,
|
|
||||||
ReadHeaderTimeout: 60 * time.Second,
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
}
|
|
||||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
|
||||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = p.s.ServeTLS(listener, "", "")
|
|
||||||
}()
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTPS2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
|
||||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
|
||||||
if connInfo.SrcAddr != nil {
|
|
||||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
|
||||||
}
|
|
||||||
_ = p.l.PutConn(wrapConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPS2HTTPSPlugin) Name() string {
|
func (p *HTTPS2HTTPSPlugin) Name() string {
|
||||||
return v1.PluginHTTPS2HTTPS
|
return v1.PluginHTTPS2HTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *HTTPS2HTTPSPlugin) Close() error {
|
|
||||||
return p.s.Close()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
// Copyright 2026 The frp Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
//go:build !frps
|
||||||
|
|
||||||
|
package httpsserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/samber/lo"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/transport"
|
||||||
|
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(handler http.Handler, crtPath, keyPath string, enableHTTP2 *bool) (*http.Server, error) {
|
||||||
|
tlsConfig, err := transport.NewServerTLSConfig(crtPath, keyPath, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Handler: withMisdirectedRequestCheck(handler),
|
||||||
|
ReadHeaderTimeout: 60 * time.Second,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
if !lo.FromPtr(enableHTTP2) {
|
||||||
|
server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||||
|
}
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func withMisdirectedRequestCheck(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.TLS != nil {
|
||||||
|
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||||
|
host, _ := httppkg.CanonicalHost(r.Host)
|
||||||
|
if tlsServerName != "" && tlsServerName != host {
|
||||||
|
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -44,6 +44,67 @@ func NewManager() *Manager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newPluginRequestContext() (context.Context, *xlog.Logger) {
|
||||||
|
reqid, _ := util.RandID()
|
||||||
|
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||||
|
ctx := xlog.NewContext(context.Background(), xl)
|
||||||
|
return NewReqidContext(ctx, reqid), xl
|
||||||
|
}
|
||||||
|
|
||||||
|
type pluginErrorLogMode bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Warn is the zero value because it is the default for mutable plugin operations.
|
||||||
|
pluginErrorLogWarn pluginErrorLogMode = false
|
||||||
|
pluginErrorLogInfo pluginErrorLogMode = true
|
||||||
|
)
|
||||||
|
|
||||||
|
func logPluginError(xl *xlog.Logger, p Plugin, op string, err error, mode pluginErrorLogMode) {
|
||||||
|
if mode == pluginErrorLogInfo {
|
||||||
|
xl.Infof("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
xl.Warnf("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleMutableContent[T any](
|
||||||
|
plugins []Plugin,
|
||||||
|
op string,
|
||||||
|
content *T,
|
||||||
|
logMode pluginErrorLogMode,
|
||||||
|
) (*T, error) {
|
||||||
|
if len(plugins) == 0 {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
res = &Response{
|
||||||
|
Reject: false,
|
||||||
|
Unchange: true,
|
||||||
|
}
|
||||||
|
retContent any
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
ctx, xl := newPluginRequestContext()
|
||||||
|
|
||||||
|
for _, p := range plugins {
|
||||||
|
res, retContent, err = p.Handle(ctx, op, *content)
|
||||||
|
if err != nil {
|
||||||
|
logPluginError(xl, p, op, err, logMode)
|
||||||
|
return nil, errors.New("send " + op + " request to plugin error")
|
||||||
|
}
|
||||||
|
if res.Reject {
|
||||||
|
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||||
|
}
|
||||||
|
if !res.Unchange {
|
||||||
|
// Preserve the existing Plugin contract: changed content must be *T.
|
||||||
|
// Buggy Plugin implementations still panic here, by design.
|
||||||
|
content = retContent.(*T)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) Register(p Plugin) {
|
func (m *Manager) Register(p Plugin) {
|
||||||
if p.IsSupport(OpLogin) {
|
if p.IsSupport(OpLogin) {
|
||||||
m.loginPlugins = append(m.loginPlugins, p)
|
m.loginPlugins = append(m.loginPlugins, p)
|
||||||
|
|
@ -66,71 +127,11 @@ func (m *Manager) Register(p Plugin) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
|
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
|
||||||
if len(m.loginPlugins) == 0 {
|
return handleMutableContent(m.loginPlugins, OpLogin, content, pluginErrorLogWarn)
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = &Response{
|
|
||||||
Reject: false,
|
|
||||||
Unchange: true,
|
|
||||||
}
|
|
||||||
retContent any
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
reqid, _ := util.RandID()
|
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.loginPlugins {
|
|
||||||
res, retContent, err = p.Handle(ctx, OpLogin, *content)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warnf("send Login request to plugin [%s] error: %v", p.Name(), err)
|
|
||||||
return nil, errors.New("send Login request to plugin error")
|
|
||||||
}
|
|
||||||
if res.Reject {
|
|
||||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
|
||||||
}
|
|
||||||
if !res.Unchange {
|
|
||||||
content = retContent.(*LoginContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
|
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
|
||||||
if len(m.newProxyPlugins) == 0 {
|
return handleMutableContent(m.newProxyPlugins, OpNewProxy, content, pluginErrorLogWarn)
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = &Response{
|
|
||||||
Reject: false,
|
|
||||||
Unchange: true,
|
|
||||||
}
|
|
||||||
retContent any
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
reqid, _ := util.RandID()
|
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.newProxyPlugins {
|
|
||||||
res, retContent, err = p.Handle(ctx, OpNewProxy, *content)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warnf("send NewProxy request to plugin [%s] error: %v", p.Name(), err)
|
|
||||||
return nil, errors.New("send NewProxy request to plugin error")
|
|
||||||
}
|
|
||||||
if res.Reject {
|
|
||||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
|
||||||
}
|
|
||||||
if !res.Unchange {
|
|
||||||
content = retContent.(*NewProxyContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||||
|
|
@ -139,10 +140,7 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
errs := make([]string, 0)
|
errs := make([]string, 0)
|
||||||
reqid, _ := util.RandID()
|
ctx, xl := newPluginRequestContext()
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.closeProxyPlugins {
|
for _, p := range m.closeProxyPlugins {
|
||||||
_, _, err := p.Handle(ctx, OpCloseProxy, *content)
|
_, _, err := p.Handle(ctx, OpCloseProxy, *content)
|
||||||
|
|
@ -159,103 +157,14 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
|
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
|
||||||
if len(m.pingPlugins) == 0 {
|
return handleMutableContent(m.pingPlugins, OpPing, content, pluginErrorLogWarn)
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = &Response{
|
|
||||||
Reject: false,
|
|
||||||
Unchange: true,
|
|
||||||
}
|
|
||||||
retContent any
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
reqid, _ := util.RandID()
|
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.pingPlugins {
|
|
||||||
res, retContent, err = p.Handle(ctx, OpPing, *content)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warnf("send Ping request to plugin [%s] error: %v", p.Name(), err)
|
|
||||||
return nil, errors.New("send Ping request to plugin error")
|
|
||||||
}
|
|
||||||
if res.Reject {
|
|
||||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
|
||||||
}
|
|
||||||
if !res.Unchange {
|
|
||||||
content = retContent.(*PingContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
|
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
|
||||||
if len(m.newWorkConnPlugins) == 0 {
|
return handleMutableContent(m.newWorkConnPlugins, OpNewWorkConn, content, pluginErrorLogWarn)
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = &Response{
|
|
||||||
Reject: false,
|
|
||||||
Unchange: true,
|
|
||||||
}
|
|
||||||
retContent any
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
reqid, _ := util.RandID()
|
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.newWorkConnPlugins {
|
|
||||||
res, retContent, err = p.Handle(ctx, OpNewWorkConn, *content)
|
|
||||||
if err != nil {
|
|
||||||
xl.Warnf("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err)
|
|
||||||
return nil, errors.New("send NewWorkConn request to plugin error")
|
|
||||||
}
|
|
||||||
if res.Reject {
|
|
||||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
|
||||||
}
|
|
||||||
if !res.Unchange {
|
|
||||||
content = retContent.(*NewWorkConnContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
|
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
|
||||||
if len(m.newUserConnPlugins) == 0 {
|
// Preserve the pre-refactor log level for NewUserConn plugin errors.
|
||||||
return content, nil
|
return handleMutableContent(m.newUserConnPlugins, OpNewUserConn, content, pluginErrorLogInfo)
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
res = &Response{
|
|
||||||
Reject: false,
|
|
||||||
Unchange: true,
|
|
||||||
}
|
|
||||||
retContent any
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
reqid, _ := util.RandID()
|
|
||||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
|
||||||
ctx := xlog.NewContext(context.Background(), xl)
|
|
||||||
ctx = NewReqidContext(ctx, reqid)
|
|
||||||
|
|
||||||
for _, p := range m.newUserConnPlugins {
|
|
||||||
res, retContent, err = p.Handle(ctx, OpNewUserConn, *content)
|
|
||||||
if err != nil {
|
|
||||||
xl.Infof("send NewUserConn request to plugin [%s] error: %v", p.Name(), err)
|
|
||||||
return nil, errors.New("send NewUserConn request to plugin error")
|
|
||||||
}
|
|
||||||
if res.Reject {
|
|
||||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
|
||||||
}
|
|
||||||
if !res.Unchange {
|
|
||||||
content = retContent.(*NewUserConnContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
336
pkg/plugin/server/manager_test.go
Normal file
336
pkg/plugin/server/manager_test.go
Normal file
|
|
@ -0,0 +1,336 @@
|
||||||
|
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
goliblog "github.com/fatedier/golib/log"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/msg"
|
||||||
|
frplog "github.com/fatedier/frp/pkg/util/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPlugin struct {
|
||||||
|
name string
|
||||||
|
ops map[string]bool
|
||||||
|
handler func(context.Context, string, any) (*Response, any, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log-capturing subtests serialize global logger swaps; do not use t.Parallel.
|
||||||
|
var logCaptureMu sync.Mutex
|
||||||
|
|
||||||
|
type logCapture struct {
|
||||||
|
bytes.Buffer
|
||||||
|
levels []goliblog.Level
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testPlugin) Name() string {
|
||||||
|
return p.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testPlugin) IsSupport(op string) bool {
|
||||||
|
return p.ops[op]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||||
|
return p.handler(ctx, op, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
|
||||||
|
w.levels = append(w.levels, level)
|
||||||
|
return w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func captureLogOutput(t *testing.T) *logCapture {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
logCaptureMu.Lock()
|
||||||
|
logOutput := &logCapture{}
|
||||||
|
oldLogger := frplog.Logger
|
||||||
|
frplog.Logger = goliblog.New(
|
||||||
|
goliblog.WithOutput(logOutput),
|
||||||
|
goliblog.WithLevel(goliblog.TraceLevel),
|
||||||
|
goliblog.WithCaller(false),
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
frplog.Logger = oldLogger
|
||||||
|
logCaptureMu.Unlock()
|
||||||
|
})
|
||||||
|
return logOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
var mutablePluginOps = []struct {
|
||||||
|
name string
|
||||||
|
op string
|
||||||
|
}{
|
||||||
|
{name: "login", op: OpLogin},
|
||||||
|
{name: "new proxy", op: OpNewProxy},
|
||||||
|
{name: "ping", op: OpPing},
|
||||||
|
{name: "new work conn", op: OpNewWorkConn},
|
||||||
|
{name: "new user conn", op: OpNewUserConn},
|
||||||
|
}
|
||||||
|
|
||||||
|
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
|
||||||
|
switch op {
|
||||||
|
case OpLogin:
|
||||||
|
got, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
|
||||||
|
if got == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return got.User, err
|
||||||
|
case OpNewProxy:
|
||||||
|
got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}})
|
||||||
|
if got == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return got.User.User, err
|
||||||
|
case OpPing:
|
||||||
|
got, err := m.Ping(&PingContent{User: UserInfo{User: user}})
|
||||||
|
if got == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return got.User.User, err
|
||||||
|
case OpNewWorkConn:
|
||||||
|
got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}})
|
||||||
|
if got == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return got.User.User, err
|
||||||
|
case OpNewUserConn:
|
||||||
|
got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}})
|
||||||
|
if got == nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return got.User.User, err
|
||||||
|
default:
|
||||||
|
panic("unsupported mutable op: " + op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutableUser(t *testing.T, op string, content any) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
switch op {
|
||||||
|
case OpLogin:
|
||||||
|
return content.(LoginContent).User
|
||||||
|
case OpNewProxy:
|
||||||
|
return content.(NewProxyContent).User.User
|
||||||
|
case OpPing:
|
||||||
|
return content.(PingContent).User.User
|
||||||
|
case OpNewWorkConn:
|
||||||
|
return content.(NewWorkConnContent).User.User
|
||||||
|
case OpNewUserConn:
|
||||||
|
return content.(NewUserConnContent).User.User
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported mutable op: %s", op)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
switch op {
|
||||||
|
case OpLogin:
|
||||||
|
got := content.(LoginContent)
|
||||||
|
got.User = user
|
||||||
|
return &got
|
||||||
|
case OpNewProxy:
|
||||||
|
got := content.(NewProxyContent)
|
||||||
|
got.User.User = user
|
||||||
|
return &got
|
||||||
|
case OpPing:
|
||||||
|
got := content.(PingContent)
|
||||||
|
got.User.User = user
|
||||||
|
return &got
|
||||||
|
case OpNewWorkConn:
|
||||||
|
got := content.(NewWorkConnContent)
|
||||||
|
got.User.User = user
|
||||||
|
return &got
|
||||||
|
case OpNewUserConn:
|
||||||
|
got := content.(NewUserConnContent)
|
||||||
|
got.User.User = user
|
||||||
|
return &got
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported mutable op: %s", op)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerMutableContentAcrossOps(t *testing.T) {
|
||||||
|
for _, tt := range mutablePluginOps {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: "mutate",
|
||||||
|
ops: map[string]bool{tt.op: true},
|
||||||
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||||
|
if op != tt.op {
|
||||||
|
t.Fatalf("unexpected op: %s", op)
|
||||||
|
}
|
||||||
|
if GetReqidFromContext(ctx) == "" {
|
||||||
|
t.Fatal("expected request id in context")
|
||||||
|
}
|
||||||
|
if got := mutableUser(t, tt.op, content); got != "initial" {
|
||||||
|
t.Fatalf("expected initial user, got %q", got)
|
||||||
|
}
|
||||||
|
return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: "observe",
|
||||||
|
ops: map[string]bool{tt.op: true},
|
||||||
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||||
|
if op != tt.op {
|
||||||
|
t.Fatalf("unexpected op: %s", op)
|
||||||
|
}
|
||||||
|
if GetReqidFromContext(ctx) == "" {
|
||||||
|
t.Fatal("expected request id in context")
|
||||||
|
}
|
||||||
|
if got := mutableUser(t, tt.op, content); got != "mutated" {
|
||||||
|
t.Fatalf("expected mutated user, got %q", got)
|
||||||
|
}
|
||||||
|
return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := callMutableWithUser(m, tt.op, "initial")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("mutable op failed: %v", err)
|
||||||
|
}
|
||||||
|
if got != "mutated" {
|
||||||
|
t.Fatalf("expected mutated user, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
|
||||||
|
var called bool
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: "reject",
|
||||||
|
ops: map[string]bool{OpPing: true},
|
||||||
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||||
|
return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: "unused",
|
||||||
|
ops: map[string]bool{OpPing: true},
|
||||||
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||||
|
called = true
|
||||||
|
return &Response{Unchange: true}, nil, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := m.Ping(&PingContent{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected reject error")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("expected no returned content, got %#v", got)
|
||||||
|
}
|
||||||
|
if err.Error() != "blocked" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if called {
|
||||||
|
t.Fatal("expected plugin chain to stop after reject")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
op string
|
||||||
|
level goliblog.Level
|
||||||
|
}{
|
||||||
|
{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
|
||||||
|
{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
logOutput := captureLogOutput(t)
|
||||||
|
m := NewManager()
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: "error",
|
||||||
|
ops: map[string]bool{tt.op: true},
|
||||||
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||||
|
return nil, nil, errors.New("boom")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := callMutableWithUser(m, tt.op, "initial")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected plugin error")
|
||||||
|
}
|
||||||
|
if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
|
||||||
|
t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
|
||||||
|
logOutput := captureLogOutput(t)
|
||||||
|
m := NewManager()
|
||||||
|
|
||||||
|
for _, name := range []string{"first", "second"} {
|
||||||
|
m.Register(testPlugin{
|
||||||
|
name: name,
|
||||||
|
ops: map[string]bool{OpCloseProxy: true},
|
||||||
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||||
|
if GetReqidFromContext(ctx) == "" {
|
||||||
|
t.Fatal("expected request id in context")
|
||||||
|
}
|
||||||
|
if op != OpCloseProxy {
|
||||||
|
t.Fatalf("unexpected op: %s", op)
|
||||||
|
}
|
||||||
|
return nil, nil, errors.New(name + " error")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.CloseProxy(&CloseProxyContent{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected close proxy error")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") {
|
||||||
|
t.Fatalf("unexpected close proxy error prefix: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") {
|
||||||
|
t.Fatalf("missing aggregated errors: %v", err)
|
||||||
|
}
|
||||||
|
if len(logOutput.levels) != 2 {
|
||||||
|
t.Fatalf("expected two warning logs, got %v", logOutput.levels)
|
||||||
|
}
|
||||||
|
for _, level := range logOutput.levels {
|
||||||
|
if level != goliblog.WarnLevel {
|
||||||
|
t.Fatalf("expected warning log level, got %v", logOutput.levels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -17,6 +17,8 @@ package metric
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DateCounter interface {
|
type DateCounter interface {
|
||||||
|
|
@ -38,27 +40,33 @@ func NewDateCounter(reserveDays int64) DateCounter {
|
||||||
type StandardDateCounter struct {
|
type StandardDateCounter struct {
|
||||||
reserveDays int64
|
reserveDays int64
|
||||||
counts []int64
|
counts []int64
|
||||||
|
clock clock.PassiveClock
|
||||||
|
|
||||||
lastUpdateDate time.Time
|
lastUpdateDate time.Time
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
|
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
|
||||||
now := time.Now()
|
return newStandardDateCounterWithClock(reserveDays, clock.RealClock{})
|
||||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
}
|
||||||
s := &StandardDateCounter{
|
|
||||||
|
func newStandardDateCounterWithClock(reserveDays int64, clk clock.PassiveClock) *StandardDateCounter {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
|
return &StandardDateCounter{
|
||||||
reserveDays: reserveDays,
|
reserveDays: reserveDays,
|
||||||
counts: make([]int64, reserveDays),
|
counts: make([]int64, reserveDays),
|
||||||
lastUpdateDate: now,
|
clock: clk,
|
||||||
|
lastUpdateDate: startOfDay(clk.Now()),
|
||||||
}
|
}
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *StandardDateCounter) TodayCount() int64 {
|
func (c *StandardDateCounter) TodayCount() int64 {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
c.rotate(time.Now())
|
c.rotate(c.clock.Now())
|
||||||
return c.counts[0]
|
return c.counts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -70,65 +78,61 @@ func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 {
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.rotate(time.Now())
|
c.rotate(c.clock.Now())
|
||||||
for i := 0; i < int(lastdays); i++ {
|
copy(counts, c.counts)
|
||||||
counts[i] = c.counts[i]
|
|
||||||
}
|
|
||||||
return counts
|
return counts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *StandardDateCounter) Inc(count int64) {
|
func (c *StandardDateCounter) Inc(count int64) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.rotate(time.Now())
|
c.rotate(c.clock.Now())
|
||||||
c.counts[0] += count
|
c.counts[0] += count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *StandardDateCounter) Dec(count int64) {
|
func (c *StandardDateCounter) Dec(count int64) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
c.rotate(time.Now())
|
c.rotate(c.clock.Now())
|
||||||
c.counts[0] -= count
|
c.counts[0] -= count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *StandardDateCounter) Snapshot() DateCounter {
|
func (c *StandardDateCounter) Snapshot() DateCounter {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
tmp := newStandardDateCounter(c.reserveDays)
|
tmp := newStandardDateCounterWithClock(c.reserveDays, c.clock)
|
||||||
for i := 0; i < int(c.reserveDays); i++ {
|
copy(tmp.counts, c.counts)
|
||||||
tmp.counts[i] = c.counts[i]
|
|
||||||
}
|
|
||||||
return tmp
|
return tmp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *StandardDateCounter) Clear() {
|
func (c *StandardDateCounter) Clear() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
for i := 0; i < int(c.reserveDays); i++ {
|
clear(c.counts)
|
||||||
c.counts[i] = 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// rotate
|
// rotate
|
||||||
// Must hold the lock before calling this function.
|
// Must hold the lock before calling this function.
|
||||||
func (c *StandardDateCounter) rotate(now time.Time) {
|
func (c *StandardDateCounter) rotate(now time.Time) {
|
||||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
now = startOfDay(now)
|
||||||
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
|
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
|
||||||
|
reserveDays := int(c.reserveDays)
|
||||||
defer func() {
|
|
||||||
c.lastUpdateDate = now
|
|
||||||
}()
|
|
||||||
|
|
||||||
if days <= 0 {
|
if days <= 0 {
|
||||||
return
|
return
|
||||||
} else if days >= int(c.reserveDays) {
|
} else if days >= reserveDays {
|
||||||
c.counts = make([]int64, c.reserveDays)
|
c.counts = make([]int64, c.reserveDays)
|
||||||
|
c.lastUpdateDate = now
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
newCounts := make([]int64, c.reserveDays)
|
newCounts := make([]int64, c.reserveDays)
|
||||||
|
|
||||||
for i := days; i < int(c.reserveDays); i++ {
|
copy(newCounts[days:], c.counts[:reserveDays-days])
|
||||||
newCounts[i] = c.counts[i-days]
|
|
||||||
}
|
|
||||||
c.counts = newCounts
|
c.counts = newCounts
|
||||||
|
c.lastUpdateDate = now
|
||||||
|
}
|
||||||
|
|
||||||
|
// startOfDay returns midnight in t's location.
|
||||||
|
func startOfDay(t time.Time) time.Time {
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
package metric
|
package metric
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDateCounter(t *testing.T) {
|
func TestDateCounter(t *testing.T) {
|
||||||
|
|
@ -25,3 +28,107 @@ func TestDateCounter(t *testing.T) {
|
||||||
dcTmp := dc.Snapshot()
|
dcTmp := dc.Snapshot()
|
||||||
require.EqualValues(5, dcTmp.TodayCount())
|
require.EqualValues(5, dcTmp.TodayCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDateCounterRotate(t *testing.T) {
|
||||||
|
loc := time.FixedZone("test", 8*60*60)
|
||||||
|
lastUpdateDate := time.Date(2026, time.May, 8, 0, 0, 0, 0, loc)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
now time.Time
|
||||||
|
want []int64
|
||||||
|
wantLastUpdateDate time.Time
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "same day",
|
||||||
|
now: time.Date(2026, time.May, 8, 12, 30, 0, 0, loc),
|
||||||
|
want: []int64{10, 7, 3},
|
||||||
|
wantLastUpdateDate: lastUpdateDate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "clock skew",
|
||||||
|
now: time.Date(2026, time.May, 7, 12, 30, 0, 0, loc),
|
||||||
|
want: []int64{10, 7, 3},
|
||||||
|
wantLastUpdateDate: lastUpdateDate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one day",
|
||||||
|
now: time.Date(2026, time.May, 9, 12, 30, 0, 0, loc),
|
||||||
|
want: []int64{0, 10, 7},
|
||||||
|
wantLastUpdateDate: time.Date(2026, time.May, 9, 0, 0, 0, 0, loc),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two days",
|
||||||
|
now: time.Date(2026, time.May, 10, 12, 30, 0, 0, loc),
|
||||||
|
want: []int64{0, 0, 10},
|
||||||
|
wantLastUpdateDate: time.Date(2026, time.May, 10, 0, 0, 0, 0, loc),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all reserved days elapsed",
|
||||||
|
now: time.Date(2026, time.May, 11, 12, 30, 0, 0, loc),
|
||||||
|
want: []int64{0, 0, 0},
|
||||||
|
wantLastUpdateDate: time.Date(2026, time.May, 11, 0, 0, 0, 0, loc),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
dc := newStandardDateCounter(3)
|
||||||
|
dc.counts = []int64{10, 7, 3}
|
||||||
|
dc.lastUpdateDate = lastUpdateDate
|
||||||
|
|
||||||
|
dc.mu.Lock()
|
||||||
|
dc.rotate(tt.now)
|
||||||
|
dc.mu.Unlock()
|
||||||
|
|
||||||
|
require.Equal(tt.want, dc.counts)
|
||||||
|
require.Equal(tt.wantLastUpdateDate, dc.lastUpdateDate)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDateCounterGetLastDaysCountReturnsCopy(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||||
|
dc := newStandardDateCounterWithClock(3, clk)
|
||||||
|
dc.counts = []int64{10, 7, 3}
|
||||||
|
|
||||||
|
counts := dc.GetLastDaysCount(2)
|
||||||
|
require.Equal([]int64{10, 7}, counts)
|
||||||
|
|
||||||
|
counts[0] = 100
|
||||||
|
require.Equal([]int64{10, 7}, dc.GetLastDaysCount(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDateCounterClear(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
dc := newStandardDateCounter(3)
|
||||||
|
dc.counts = []int64{10, 7, 3}
|
||||||
|
|
||||||
|
dc.Clear()
|
||||||
|
|
||||||
|
require.Equal([]int64{0, 0, 0}, dc.counts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDateCounterConcurrentAccess(t *testing.T) {
|
||||||
|
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||||
|
dc := newStandardDateCounterWithClock(3, clk)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for range 8 {
|
||||||
|
wg.Go(func() {
|
||||||
|
for range 100 {
|
||||||
|
dc.Inc(1)
|
||||||
|
dc.Dec(1)
|
||||||
|
_ = dc.TodayCount()
|
||||||
|
_ = dc.GetLastDaysCount(3)
|
||||||
|
_ = dc.Snapshot()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,11 +39,14 @@ type HTTPConnectTCPMuxer struct {
|
||||||
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
|
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
|
||||||
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
|
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
|
||||||
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
|
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
mux.SetCheckAuthFunc(ret.auth).
|
mux.SetCheckAuthFunc(ret.auth).
|
||||||
SetSuccessHookFunc(ret.sendConnectResponse).
|
SetSuccessHookFunc(ret.sendConnectResponse).
|
||||||
SetFailHookFunc(vhostFailed)
|
SetFailHookFunc(vhostFailed)
|
||||||
ret.Muxer = mux
|
ret.Muxer = mux
|
||||||
return ret, err
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
|
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
libio "github.com/fatedier/golib/io"
|
libio "github.com/fatedier/golib/io"
|
||||||
|
|
@ -160,7 +159,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
|
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
|
||||||
vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
|
vr, ok := rp.vhostRouter.getByRoute(domain, location, routeByHTTPUser)
|
||||||
if ok {
|
if ok {
|
||||||
log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
|
log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
|
||||||
return vr.payload.(*RouteConfig)
|
return vr.payload.(*RouteConfig)
|
||||||
|
|
@ -171,7 +170,7 @@ func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser str
|
||||||
// CreateConnection create a new connection by route config
|
// CreateConnection create a new connection by route config
|
||||||
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
|
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
|
||||||
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
|
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
|
||||||
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
vr, ok := rp.vhostRouter.getByRoute(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
||||||
if ok {
|
if ok {
|
||||||
if byEndpoint {
|
if byEndpoint {
|
||||||
fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn
|
fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn
|
||||||
|
|
@ -208,50 +207,6 @@ func checkRouteAuthByRequest(req *http.Request, rc *RouteConfig) bool {
|
||||||
return ok && user == rc.Username && passwd == rc.Password
|
return ok && user == rc.Username && passwd == rc.Password
|
||||||
}
|
}
|
||||||
|
|
||||||
// getVhost tries to get vhost router by route policy.
|
|
||||||
func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) {
|
|
||||||
findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) {
|
|
||||||
vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser)
|
|
||||||
if ok {
|
|
||||||
return vr, ok
|
|
||||||
}
|
|
||||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
|
||||||
vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "")
|
|
||||||
if ok {
|
|
||||||
return vr, ok
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// First we check the full hostname
|
|
||||||
// if not exist, then check the wildcard_domain such as *.example.com
|
|
||||||
vr, ok := findRouter(domain, location, routeByHTTPUser)
|
|
||||||
if ok {
|
|
||||||
return vr, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// e.g. domain = test.example.com, try to match wildcard domains.
|
|
||||||
// *.example.com
|
|
||||||
// *.com
|
|
||||||
domainSplit := strings.Split(domain, ".")
|
|
||||||
for len(domainSplit) >= 3 {
|
|
||||||
domainSplit[0] = "*"
|
|
||||||
domain = strings.Join(domainSplit, ".")
|
|
||||||
vr, ok = findRouter(domain, location, routeByHTTPUser)
|
|
||||||
if ok {
|
|
||||||
return vr, true
|
|
||||||
}
|
|
||||||
domainSplit = domainSplit[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
|
||||||
vr, ok = findRouter("*", location, routeByHTTPUser)
|
|
||||||
if ok {
|
|
||||||
return vr, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
|
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
|
||||||
hj, ok := rw.(http.Hijacker)
|
hj, ok := rw.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
||||||
|
|
@ -29,11 +29,11 @@ type HTTPSMuxer struct {
|
||||||
|
|
||||||
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
|
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
|
||||||
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
|
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
|
||||||
mux.SetFailHookFunc(vhostFailed)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &HTTPSMuxer{mux}, err
|
mux.SetFailHookFunc(vhostFailed)
|
||||||
|
return &HTTPSMuxer{mux}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||||
|
|
|
||||||
|
|
@ -16,23 +16,53 @@ func TestGetHTTPSHostname(t *testing.T) {
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
var conn net.Conn
|
connCh := make(chan net.Conn, 1)
|
||||||
|
acceptErrCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
conn, _ = l.Accept()
|
conn, err := l.Accept()
|
||||||
require.NotNil(conn)
|
if err != nil {
|
||||||
|
acceptErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
connCh <- conn
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
clientErrCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
conn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ServerName: "example.com",
|
ServerName: "example.com",
|
||||||
})
|
})
|
||||||
|
if conn != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
clientErrCh <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
var conn net.Conn
|
||||||
_, infos, err := GetHTTPSHostname(conn)
|
select {
|
||||||
|
case conn = <-connCh:
|
||||||
|
case err := <-acceptErrCh:
|
||||||
|
require.NoError(err)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for accepted connection")
|
||||||
|
}
|
||||||
|
require.NotNil(conn)
|
||||||
|
|
||||||
|
serverConn, infos, err := GetHTTPSHostname(conn)
|
||||||
|
if serverConn != nil {
|
||||||
|
_ = serverConn.Close()
|
||||||
|
} else {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
require.Equal("example.com", infos["Host"])
|
require.Equal("example.com", infos["Host"])
|
||||||
require.Equal("https", infos["Scheme"])
|
require.Equal("https", infos["Scheme"])
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-clientErrCh:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for TLS client")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -93,12 +93,19 @@ func (r *Routers) Del(domain, location, httpUser string) {
|
||||||
routersByHTTPUser[httpUser] = newVrs
|
routersByHTTPUser[httpUser] = newVrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get returns the best location match for an exact host and exact HTTP user.
|
||||||
|
// It does not apply all-users, wildcard-domain, or catch-all-domain fallback.
|
||||||
func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
||||||
host = strings.ToLower(host)
|
host = strings.ToLower(host)
|
||||||
|
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
defer r.mutex.RUnlock()
|
defer r.mutex.RUnlock()
|
||||||
|
|
||||||
|
return r.getLocked(host, path, httpUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLocked performs an exact-host lookup; host must already be lower-cased.
|
||||||
|
func (r *Routers) getLocked(host, path, httpUser string) (vr *Router, exist bool) {
|
||||||
routersByHTTPUser, found := r.indexByDomain[host]
|
routersByHTTPUser, found := r.indexByDomain[host]
|
||||||
if !found {
|
if !found {
|
||||||
return
|
return
|
||||||
|
|
@ -117,6 +124,49 @@ func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Routers) getByRoute(host, path, httpUser string) (*Router, bool) {
|
||||||
|
host = strings.ToLower(host)
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
|
||||||
|
// First we check the full hostname; if it doesn't exist, then check wildcard domains.
|
||||||
|
// For example, test.example.com checks *.example.com before falling back to "*".
|
||||||
|
vr, ok := r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||||
|
if ok {
|
||||||
|
return vr, true
|
||||||
|
}
|
||||||
|
|
||||||
|
hostSplit := strings.Split(host, ".")
|
||||||
|
// Keep two-label hosts out of the wildcard walk, so example.com does not match *.com.
|
||||||
|
for len(hostSplit) >= 3 {
|
||||||
|
// Replace the leftmost remaining label with the wildcard marker.
|
||||||
|
hostSplit[0] = "*"
|
||||||
|
host = strings.Join(hostSplit, ".")
|
||||||
|
vr, ok = r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||||
|
if ok {
|
||||||
|
return vr, true
|
||||||
|
}
|
||||||
|
hostSplit = hostSplit[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, try to check if there is one proxy whose domain is "*", which means match all domains.
|
||||||
|
return r.getExactOrAllUsersLocked("*", path, httpUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Routers) getExactOrAllUsersLocked(host, path, httpUser string) (*Router, bool) {
|
||||||
|
vr, ok := r.getLocked(host, path, httpUser)
|
||||||
|
if ok {
|
||||||
|
return vr, true
|
||||||
|
}
|
||||||
|
// Try to check if there is one proxy that doesn't specify routeByHTTPUser, it means match all.
|
||||||
|
vr, ok = r.getLocked(host, path, "")
|
||||||
|
if ok {
|
||||||
|
return vr, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
|
func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
|
||||||
routersByHTTPUser, found := r.indexByDomain[host]
|
routersByHTTPUser, found := r.indexByDomain[host]
|
||||||
if !found {
|
if !found {
|
||||||
|
|
|
||||||
257
pkg/util/vhost/router_test.go
Normal file
257
pkg/util/vhost/router_test.go
Normal file
|
|
@ -0,0 +1,257 @@
|
||||||
|
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package vhost
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoutersGet(t *testing.T) {
|
||||||
|
routers := NewRouters()
|
||||||
|
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||||
|
require.NoError(t, routers.Add("example.com", "/public", "", "exact-all-users"))
|
||||||
|
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||||
|
require.NoError(t, routers.Add("*", "/api", "", "all-domains"))
|
||||||
|
|
||||||
|
t.Run("exact host and user match with normalized domain", func(t *testing.T) {
|
||||||
|
router, ok := routers.Get("EXAMPLE.COM", "/api/users", "alice")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "exact-user", router.payload)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exact host and empty user match", func(t *testing.T) {
|
||||||
|
router, ok := routers.Get("EXAMPLE.COM", "/public/docs", "")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "exact-all-users", router.payload)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not fall back from named user to empty user", func(t *testing.T) {
|
||||||
|
// Get intentionally requires an exact HTTP user; route-level fallbacks live in getByRoute.
|
||||||
|
_, ok := routers.Get("EXAMPLE.COM", "/public/docs", "alice")
|
||||||
|
require.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not fall back to wildcard domains", func(t *testing.T) {
|
||||||
|
_, ok := routers.Get("foo.example.com", "/api/users", "")
|
||||||
|
require.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not fall back to catch-all domain", func(t *testing.T) {
|
||||||
|
_, ok := routers.Get("missing.test", "/api/users", "")
|
||||||
|
require.False(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoutersGetByRoute(t *testing.T) {
|
||||||
|
routers := NewRouters()
|
||||||
|
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||||
|
require.NoError(t, routers.Add("example.com", "/api", "", "exact-all-users"))
|
||||||
|
require.NoError(t, routers.Add("exact.example.com", "/api", "", "exact-subdomain"))
|
||||||
|
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||||
|
require.NoError(t, routers.Add("*.foo.example.com", "/api", "", "specific-wildcard"))
|
||||||
|
require.NoError(t, routers.Add("*.bar.com", "/api", "", "wildcard-parent-domain"))
|
||||||
|
require.NoError(t, routers.Add("*", "/admin", "root", "all-domains-user"))
|
||||||
|
require.NoError(t, routers.Add("*", "/", "", "all-domains"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domain string
|
||||||
|
location string
|
||||||
|
httpUser string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain and http user",
|
||||||
|
domain: "example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "alice",
|
||||||
|
want: "exact-user",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact domain falls back to all users",
|
||||||
|
domain: "example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "exact-all-users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard domain uses all users fallback",
|
||||||
|
domain: "foo.example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "wildcard-all-users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed-case domain is normalized",
|
||||||
|
domain: "Foo.Example.Com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "wildcard-all-users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact domain wins over wildcard domain",
|
||||||
|
domain: "exact.example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "exact-subdomain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "more specific wildcard wins over broader wildcard",
|
||||||
|
domain: "bar.foo.example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "specific-wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard walk checks parent domains",
|
||||||
|
domain: "a.b.bar.com",
|
||||||
|
location: "/api/users",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "wildcard-parent-domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "catch-all domain fallback",
|
||||||
|
domain: "foo.test.com",
|
||||||
|
location: "/other",
|
||||||
|
httpUser: "bob",
|
||||||
|
want: "all-domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "catch-all domain honors http user",
|
||||||
|
domain: "foo.test.com",
|
||||||
|
location: "/admin/panel",
|
||||||
|
httpUser: "root",
|
||||||
|
want: "all-domains-user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
router, ok := routers.getByRoute(tt.domain, tt.location, tt.httpUser)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, tt.want, router.payload)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoutersGetByRouteNoMatch(t *testing.T) {
|
||||||
|
routers := NewRouters()
|
||||||
|
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||||
|
require.NoError(t, routers.Add("*.com", "/api", "", "top-level-wildcard"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domain string
|
||||||
|
location string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "two-label domain does not enter wildcard walk",
|
||||||
|
domain: "example.com",
|
||||||
|
location: "/api/users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing catch-all remains no match",
|
||||||
|
domain: "foo.test.com",
|
||||||
|
location: "/api/users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong path remains no match",
|
||||||
|
domain: "foo.example.com",
|
||||||
|
location: "/other",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
router, ok := routers.getByRoute(tt.domain, tt.location, "bob")
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Nil(t, router)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoutersConcurrentGetByRouteAndAdd(t *testing.T) {
|
||||||
|
routers := NewRouters()
|
||||||
|
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard"))
|
||||||
|
|
||||||
|
const readers = 8
|
||||||
|
const iterations = 200
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
start := make(chan struct{})
|
||||||
|
errCh := make(chan error, readers+1)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for id := range readers {
|
||||||
|
wg.Go(func() {
|
||||||
|
<-start
|
||||||
|
for j := range iterations {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
router, ok := routers.getByRoute("foo.example.com", "/api/users", "")
|
||||||
|
if !ok || router == nil || router.payload != "wildcard" {
|
||||||
|
errCh <- fmt.Errorf("reader %d iteration %d got router=%v ok=%v", id, j, router, ok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
|
<-start
|
||||||
|
for i := range iterations {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
err := routers.Add(fmt.Sprintf("host-%d.example.com", i), "/api", "", i)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
cancel()
|
||||||
|
t.Fatal("concurrent route lookup and add timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(errCh)
|
||||||
|
for err := range errCh {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -148,41 +148,9 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
|
func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
|
||||||
findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) {
|
vr, ok := v.registryRouter.getByRoute(name, path, httpUser)
|
||||||
vr, ok := v.registryRouter.Get(inName, inPath, inHTTPUser)
|
|
||||||
if ok {
|
|
||||||
return vr.payload.(*Listener), true
|
|
||||||
}
|
|
||||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
|
||||||
vr, ok = v.registryRouter.Get(inName, inPath, "")
|
|
||||||
if ok {
|
|
||||||
return vr.payload.(*Listener), true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// first we check the full hostname
|
|
||||||
// if not exist, then check the wildcard_domain such as *.example.com
|
|
||||||
l, ok := findRouter(name, path, httpUser)
|
|
||||||
if ok {
|
if ok {
|
||||||
return l, true
|
return vr.payload.(*Listener), true
|
||||||
}
|
|
||||||
|
|
||||||
domainSplit := strings.Split(name, ".")
|
|
||||||
for len(domainSplit) >= 3 {
|
|
||||||
domainSplit[0] = "*"
|
|
||||||
name = strings.Join(domainSplit, ".")
|
|
||||||
|
|
||||||
l, ok = findRouter(name, path, httpUser)
|
|
||||||
if ok {
|
|
||||||
return l, true
|
|
||||||
}
|
|
||||||
domainSplit = domainSplit[1:]
|
|
||||||
}
|
|
||||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
|
||||||
l, ok = findRouter("*", path, httpUser)
|
|
||||||
if ok {
|
|
||||||
return l, true
|
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/util/util"
|
"github.com/fatedier/frp/pkg/util/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -48,6 +50,7 @@ type FastBackoffOptions struct {
|
||||||
|
|
||||||
type fastBackoffImpl struct {
|
type fastBackoffImpl struct {
|
||||||
options FastBackoffOptions
|
options FastBackoffOptions
|
||||||
|
clock clock.PassiveClock
|
||||||
|
|
||||||
lastCalledTime time.Time
|
lastCalledTime time.Time
|
||||||
consecutiveErrCount int
|
consecutiveErrCount int
|
||||||
|
|
@ -57,18 +60,26 @@ type fastBackoffImpl struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
|
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
|
||||||
|
return newFastBackoffManagerWithClock(options, clock.RealClock{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFastBackoffManagerWithClock(options FastBackoffOptions, clk clock.PassiveClock) BackoffManager {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
return &fastBackoffImpl{
|
return &fastBackoffImpl{
|
||||||
options: options,
|
options: options,
|
||||||
|
clock: clk,
|
||||||
countsInFastRetryWindow: 1,
|
countsInFastRetryWindow: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
|
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
|
||||||
if f.lastCalledTime.IsZero() {
|
if f.lastCalledTime.IsZero() {
|
||||||
f.lastCalledTime = time.Now()
|
f.lastCalledTime = f.clock.Now()
|
||||||
return f.options.Duration
|
return f.options.Duration
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := f.clock.Now()
|
||||||
f.lastCalledTime = now
|
f.lastCalledTime = now
|
||||||
|
|
||||||
if previousConditionError {
|
if previousConditionError {
|
||||||
|
|
|
||||||
27
pkg/util/wait/backoff_test.go
Normal file
27
pkg/util/wait/backoff_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
package wait
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFastBackoffManagerUsesClock(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
backoff := newFastBackoffManagerWithClock(FastBackoffOptions{
|
||||||
|
Duration: time.Second,
|
||||||
|
}, clk).(*fastBackoffImpl)
|
||||||
|
|
||||||
|
require.Equal(time.Second, backoff.Backoff(0, false))
|
||||||
|
require.Equal(start, backoff.lastCalledTime)
|
||||||
|
|
||||||
|
next := start.Add(time.Minute)
|
||||||
|
clk.SetTime(next)
|
||||||
|
require.Equal(time.Second, backoff.Backoff(time.Second, false))
|
||||||
|
require.Equal(next, backoff.lastCalledTime)
|
||||||
|
}
|
||||||
|
|
@ -19,7 +19,6 @@ import "strings"
|
||||||
// LogWriter forwards writes to frp's logger at configurable level.
|
// LogWriter forwards writes to frp's logger at configurable level.
|
||||||
// It is safe for concurrent use as long as the underlying Logger is thread-safe.
|
// It is safe for concurrent use as long as the underlying Logger is thread-safe.
|
||||||
type LogWriter struct {
|
type LogWriter struct {
|
||||||
xl *Logger
|
|
||||||
logFunc func(string)
|
logFunc func(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -31,35 +30,30 @@ func (w LogWriter) Write(p []byte) (n int, err error) {
|
||||||
|
|
||||||
func NewTraceWriter(xl *Logger) LogWriter {
|
func NewTraceWriter(xl *Logger) LogWriter {
|
||||||
return LogWriter{
|
return LogWriter{
|
||||||
xl: xl,
|
|
||||||
logFunc: func(msg string) { xl.Tracef("%s", msg) },
|
logFunc: func(msg string) { xl.Tracef("%s", msg) },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDebugWriter(xl *Logger) LogWriter {
|
func NewDebugWriter(xl *Logger) LogWriter {
|
||||||
return LogWriter{
|
return LogWriter{
|
||||||
xl: xl,
|
|
||||||
logFunc: func(msg string) { xl.Debugf("%s", msg) },
|
logFunc: func(msg string) { xl.Debugf("%s", msg) },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInfoWriter(xl *Logger) LogWriter {
|
func NewInfoWriter(xl *Logger) LogWriter {
|
||||||
return LogWriter{
|
return LogWriter{
|
||||||
xl: xl,
|
|
||||||
logFunc: func(msg string) { xl.Infof("%s", msg) },
|
logFunc: func(msg string) { xl.Infof("%s", msg) },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWarnWriter(xl *Logger) LogWriter {
|
func NewWarnWriter(xl *Logger) LogWriter {
|
||||||
return LogWriter{
|
return LogWriter{
|
||||||
xl: xl,
|
|
||||||
logFunc: func(msg string) { xl.Warnf("%s", msg) },
|
logFunc: func(msg string) { xl.Warnf("%s", msg) },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewErrorWriter(xl *Logger) LogWriter {
|
func NewErrorWriter(xl *Logger) LogWriter {
|
||||||
return LogWriter{
|
return LogWriter{
|
||||||
xl: xl,
|
|
||||||
logFunc: func(msg string) { xl.Errorf("%s", msg) },
|
logFunc: func(msg string) { xl.Errorf("%s", msg) },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -286,7 +286,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.routes[name] = &routeElement{
|
r.routes[name] = &routeElement{
|
||||||
name: name,
|
|
||||||
routes: routes,
|
routes: routes,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
}
|
}
|
||||||
|
|
@ -383,7 +382,6 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeElement struct {
|
type routeElement struct {
|
||||||
name string
|
|
||||||
routes []net.IPNet
|
routes []net.IPNet
|
||||||
conn io.ReadWriteCloser
|
conn io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/config/types"
|
"github.com/fatedier/frp/pkg/config/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,16 +40,25 @@ type Manager struct {
|
||||||
|
|
||||||
bindAddr string
|
bindAddr string
|
||||||
netType string
|
netType string
|
||||||
|
clock clock.WithTicker
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
|
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
|
||||||
|
return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
pm := &Manager{
|
pm := &Manager{
|
||||||
reservedPorts: make(map[string]*PortCtx),
|
reservedPorts: make(map[string]*PortCtx),
|
||||||
usedPorts: make(map[int]*PortCtx),
|
usedPorts: make(map[int]*PortCtx),
|
||||||
freePorts: make(map[int]struct{}),
|
freePorts: make(map[int]struct{}),
|
||||||
bindAddr: bindAddr,
|
bindAddr: bindAddr,
|
||||||
netType: netType,
|
netType: netType,
|
||||||
|
clock: clk,
|
||||||
}
|
}
|
||||||
if len(allowPorts) > 0 {
|
if len(allowPorts) > 0 {
|
||||||
for _, pair := range allowPorts {
|
for _, pair := range allowPorts {
|
||||||
|
|
@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||||
portCtx := &PortCtx{
|
portCtx := &PortCtx{
|
||||||
ProxyName: name,
|
ProxyName: name,
|
||||||
Closed: false,
|
Closed: false,
|
||||||
UpdateTime: time.Now(),
|
UpdateTime: pm.clock.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|
@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||||
if ctx, ok := pm.reservedPorts[name]; ok {
|
if ctx, ok := pm.reservedPorts[name]; ok {
|
||||||
if pm.isPortAvailable(ctx.Port) {
|
if pm.isPortAvailable(ctx.Port) {
|
||||||
realPort = ctx.Port
|
realPort = ctx.Port
|
||||||
pm.usedPorts[realPort] = portCtx
|
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||||
pm.reservedPorts[name] = portCtx
|
|
||||||
delete(pm.freePorts, realPort)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||||
}
|
}
|
||||||
if pm.isPortAvailable(k) {
|
if pm.isPortAvailable(k) {
|
||||||
realPort = k
|
realPort = k
|
||||||
pm.usedPorts[realPort] = portCtx
|
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||||
pm.reservedPorts[name] = portCtx
|
|
||||||
delete(pm.freePorts, realPort)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||||
if _, ok = pm.freePorts[port]; ok {
|
if _, ok = pm.freePorts[port]; ok {
|
||||||
if pm.isPortAvailable(port) {
|
if pm.isPortAvailable(port) {
|
||||||
realPort = port
|
realPort = port
|
||||||
pm.usedPorts[realPort] = portCtx
|
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||||
pm.reservedPorts[name] = portCtx
|
|
||||||
delete(pm.freePorts, realPort)
|
|
||||||
} else {
|
} else {
|
||||||
err = ErrPortUnAvailable
|
err = ErrPortUnAvailable
|
||||||
}
|
}
|
||||||
|
|
@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// markPortAcquiredLocked records a successful acquisition. pm.mu must be held.
|
||||||
|
func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) {
|
||||||
|
pm.usedPorts[port] = portCtx
|
||||||
|
pm.reservedPorts[name] = portCtx
|
||||||
|
delete(pm.freePorts, port)
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *Manager) isPortAvailable(port int) bool {
|
func (pm *Manager) isPortAvailable(port int) bool {
|
||||||
if pm.netType == "udp" {
|
if pm.netType == "udp" {
|
||||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
|
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
|
||||||
|
|
@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) {
|
||||||
pm.freePorts[port] = struct{}{}
|
pm.freePorts[port] = struct{}{}
|
||||||
delete(pm.usedPorts, port)
|
delete(pm.usedPorts, port)
|
||||||
ctx.Closed = true
|
ctx.Closed = true
|
||||||
ctx.UpdateTime = time.Now()
|
ctx.UpdateTime = pm.clock.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release reserved port if it isn't used in last 24 hours.
|
// Release reserved port if it isn't used in last 24 hours.
|
||||||
func (pm *Manager) cleanReservedPortsWorker() {
|
func (pm *Manager) cleanReservedPortsWorker() {
|
||||||
|
pm.cleanReservedPortsWorkerUntil(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) {
|
||||||
|
ticker := pm.clock.NewTicker(CleanReservedPortsInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(CleanReservedPortsInterval)
|
select {
|
||||||
pm.mu.Lock()
|
case <-ticker.C():
|
||||||
for name, ctx := range pm.reservedPorts {
|
pm.cleanReservedPortsOnce()
|
||||||
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
case <-stopCh:
|
||||||
delete(pm.reservedPorts, name)
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *Manager) cleanReservedPortsOnce() {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
for name, ctx := range pm.reservedPorts {
|
||||||
|
if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||||
|
delete(pm.reservedPorts, name)
|
||||||
}
|
}
|
||||||
pm.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
72
server/ports/ports_test.go
Normal file
72
server/ports/ports_test.go
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
package ports
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
|
|
||||||
|
"github.com/fatedier/frp/pkg/config/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManagerUsesClockForPortTimestamps(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
port := freeTCPPort(t)
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||||
|
|
||||||
|
realPort, err := pm.Acquire("proxy", port)
|
||||||
|
require.NoError(err)
|
||||||
|
require.Equal(port, realPort)
|
||||||
|
require.Equal(start, pm.usedPorts[port].UpdateTime)
|
||||||
|
|
||||||
|
releasedAt := start.Add(time.Minute)
|
||||||
|
clk.SetTime(releasedAt)
|
||||||
|
pm.Release(port)
|
||||||
|
|
||||||
|
require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
port := freeTCPPort(t)
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||||
|
|
||||||
|
realPort, err := pm.Acquire("proxy", port)
|
||||||
|
require.NoError(err)
|
||||||
|
require.Equal(port, realPort)
|
||||||
|
pm.Release(port)
|
||||||
|
require.True(pm.hasReservedPort("proxy"))
|
||||||
|
|
||||||
|
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||||
|
clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute)
|
||||||
|
|
||||||
|
require.Eventually(func() bool {
|
||||||
|
return !pm.hasReservedPort("proxy")
|
||||||
|
}, time.Second, time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *Manager) hasReservedPort(name string) bool {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
_, ok := pm.reservedPorts[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func freeTCPPort(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
return listener.Addr().(*net.TCPAddr).Port
|
||||||
|
}
|
||||||
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"k8s.io/utils/clock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientInfo captures metadata about a connected frpc instance.
|
// ClientInfo captures metadata about a connected frpc instance.
|
||||||
|
|
@ -42,12 +44,21 @@ type ClientRegistry struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
clients map[string]*ClientInfo
|
clients map[string]*ClientInfo
|
||||||
runIndex map[string]string
|
runIndex map[string]string
|
||||||
|
clock clock.PassiveClock
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientRegistry() *ClientRegistry {
|
func NewClientRegistry() *ClientRegistry {
|
||||||
|
return newClientRegistryWithClock(clock.RealClock{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry {
|
||||||
|
if clk == nil {
|
||||||
|
clk = clock.RealClock{}
|
||||||
|
}
|
||||||
return &ClientRegistry{
|
return &ClientRegistry{
|
||||||
clients: make(map[string]*ClientInfo),
|
clients: make(map[string]*ClientInfo),
|
||||||
runIndex: make(map[string]string),
|
runIndex: make(map[string]string),
|
||||||
|
clock: clk,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -64,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
||||||
key = cr.composeClientKey(user, effectiveID)
|
key = cr.composeClientKey(user, effectiveID)
|
||||||
enforceUnique := rawClientID != ""
|
enforceUnique := rawClientID != ""
|
||||||
|
|
||||||
now := time.Now()
|
now := cr.clock.Now()
|
||||||
cr.mu.Lock()
|
cr.mu.Lock()
|
||||||
defer cr.mu.Unlock()
|
defer cr.mu.Unlock()
|
||||||
|
|
||||||
|
|
@ -116,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) {
|
||||||
} else {
|
} else {
|
||||||
info.RunID = ""
|
info.RunID = ""
|
||||||
info.Online = false
|
info.Online = false
|
||||||
now := time.Now()
|
now := cr.clock.Now()
|
||||||
info.DisconnectedAt = now
|
info.DisconnectedAt = now
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
clocktesting "k8s.io/utils/clock/testing"
|
||||||
|
|
||||||
"github.com/fatedier/frp/pkg/proto/wire"
|
"github.com/fatedier/frp/pkg/proto/wire"
|
||||||
)
|
)
|
||||||
|
|
@ -35,3 +38,37 @@ func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
|
||||||
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientRegistryUsesClockForTimestamps(t *testing.T) {
|
||||||
|
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||||
|
clk := clocktesting.NewFakeClock(start)
|
||||||
|
registry := newClientRegistryWithClock(clk)
|
||||||
|
|
||||||
|
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||||
|
if conflict {
|
||||||
|
t.Fatal("unexpected client conflict")
|
||||||
|
}
|
||||||
|
|
||||||
|
info, ok := registry.GetByKey(key)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("client %q not found", key)
|
||||||
|
}
|
||||||
|
if !info.FirstConnectedAt.Equal(start) {
|
||||||
|
t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt)
|
||||||
|
}
|
||||||
|
if !info.LastConnectedAt.Equal(start) {
|
||||||
|
t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
disconnectedAt := start.Add(time.Minute)
|
||||||
|
clk.SetTime(disconnectedAt)
|
||||||
|
registry.MarkOfflineByRunID("run-id")
|
||||||
|
|
||||||
|
info, ok = registry.GetByKey(key)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("client %q not found after disconnect", key)
|
||||||
|
}
|
||||||
|
if !info.DisconnectedAt.Equal(disconnectedAt) {
|
||||||
|
t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
package framework
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CleanupActionHandle is an integer pointer type for handling cleanup action
|
|
||||||
type CleanupActionHandle *int
|
|
||||||
|
|
||||||
type cleanupFuncHandle struct {
|
|
||||||
actionHandle CleanupActionHandle
|
|
||||||
actionHook func()
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
cleanupActionsLock sync.Mutex
|
|
||||||
cleanupHookList = []cleanupFuncHandle{}
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddCleanupAction installs a function that will be called in the event of the
|
|
||||||
// whole test being terminated. This allows arbitrary pieces of the overall
|
|
||||||
// test to hook into SynchronizedAfterSuite().
|
|
||||||
// The hooks are called in last-in-first-out order.
|
|
||||||
func AddCleanupAction(fn func()) CleanupActionHandle {
|
|
||||||
p := CleanupActionHandle(new(int))
|
|
||||||
cleanupActionsLock.Lock()
|
|
||||||
defer cleanupActionsLock.Unlock()
|
|
||||||
c := cleanupFuncHandle{actionHandle: p, actionHook: fn}
|
|
||||||
cleanupHookList = append([]cleanupFuncHandle{c}, cleanupHookList...)
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveCleanupAction removes a function that was installed by
|
|
||||||
// AddCleanupAction.
|
|
||||||
func RemoveCleanupAction(p CleanupActionHandle) {
|
|
||||||
cleanupActionsLock.Lock()
|
|
||||||
defer cleanupActionsLock.Unlock()
|
|
||||||
for i, item := range cleanupHookList {
|
|
||||||
if item.actionHandle == p {
|
|
||||||
cleanupHookList = append(cleanupHookList[:i], cleanupHookList[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunCleanupActions runs all functions installed by AddCleanupAction. It does
|
|
||||||
// not remove them (see RemoveCleanupAction) but it does run unlocked, so they
|
|
||||||
// may remove themselves.
|
|
||||||
func RunCleanupActions() {
|
|
||||||
list := []func(){}
|
|
||||||
func() {
|
|
||||||
cleanupActionsLock.Lock()
|
|
||||||
defer cleanupActionsLock.Unlock()
|
|
||||||
for _, p := range cleanupHookList {
|
|
||||||
list = append(list, p.actionHook)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
// Run unlocked.
|
|
||||||
for _, fn := range list {
|
|
||||||
fn()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -23,11 +23,6 @@ func ExpectNotEqual(actual any, extra any, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...)
|
gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpectError expects an error happens, otherwise an exception raises
|
|
||||||
func ExpectError(err error, explain ...any) {
|
|
||||||
gomega.ExpectWithOffset(1, err).To(gomega.HaveOccurred(), explain...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectErrorWithOffset(offset int, err error, explain ...any) {
|
func ExpectErrorWithOffset(offset int, err error, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...)
|
gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...)
|
||||||
}
|
}
|
||||||
|
|
@ -47,11 +42,6 @@ func ExpectContainSubstring(actual, substr string, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...)
|
gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpectConsistOf expects actual contains precisely the extra elements. The ordering of the elements does not matter.
|
|
||||||
func ExpectConsistOf(actual any, extra any, explain ...any) {
|
|
||||||
gomega.ExpectWithOffset(1, actual).To(gomega.ConsistOf(extra), explain...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectContainElements(actual any, extra any, explain ...any) {
|
func ExpectContainElements(actual any, extra any, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...)
|
gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...)
|
||||||
}
|
}
|
||||||
|
|
@ -60,16 +50,6 @@ func ExpectNotContainElements(actual any, extra any, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...)
|
gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpectHaveKey expects the actual map has the key in the keyset
|
|
||||||
func ExpectHaveKey(actual any, key any, explain ...any) {
|
|
||||||
gomega.ExpectWithOffset(1, actual).To(gomega.HaveKey(key), explain...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExpectEmpty expects actual is empty
|
|
||||||
func ExpectEmpty(actual any, explain ...any) {
|
|
||||||
gomega.ExpectWithOffset(1, actual).To(gomega.BeEmpty(), explain...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectTrue(actual any, explain ...any) {
|
func ExpectTrue(actual any, explain ...any) {
|
||||||
gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...)
|
gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,11 +38,6 @@ type Framework struct {
|
||||||
// Multiple default mock servers used for e2e testing.
|
// Multiple default mock servers used for e2e testing.
|
||||||
mockServers *MockServers
|
mockServers *MockServers
|
||||||
|
|
||||||
// To make sure that this framework cleans up after itself, no matter what,
|
|
||||||
// we install a Cleanup action before each test and clear it after. If we
|
|
||||||
// should abort, the AfterSuite hook should run all Cleanup actions.
|
|
||||||
cleanupHandle CleanupActionHandle
|
|
||||||
|
|
||||||
// beforeEachStarted indicates that BeforeEach has started
|
// beforeEachStarted indicates that BeforeEach has started
|
||||||
beforeEachStarted bool
|
beforeEachStarted bool
|
||||||
|
|
||||||
|
|
@ -87,8 +82,6 @@ func NewFramework(opt Options) *Framework {
|
||||||
func (f *Framework) BeforeEach() {
|
func (f *Framework) BeforeEach() {
|
||||||
f.beforeEachStarted = true
|
f.beforeEachStarted = true
|
||||||
|
|
||||||
f.cleanupHandle = AddCleanupAction(f.AfterEach)
|
|
||||||
|
|
||||||
dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*")
|
dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*")
|
||||||
ExpectNoError(err)
|
ExpectNoError(err)
|
||||||
f.TempDirectory = dir
|
f.TempDirectory = dir
|
||||||
|
|
@ -113,8 +106,6 @@ func (f *Framework) AfterEach() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
RemoveCleanupAction(f.cleanupHandle)
|
|
||||||
|
|
||||||
// stop processor
|
// stop processor
|
||||||
for _, p := range f.serverProcesses {
|
for _, p := range f.serverProcesses {
|
||||||
_ = p.Stop()
|
_ = p.Stop()
|
||||||
|
|
@ -266,10 +257,6 @@ func (f *Framework) AllocPortExcludingRanges(ranges ...[2]int) int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Framework) ReleasePort(port int) {
|
|
||||||
f.portAllocator.Release(port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Framework) RunServer(portName string, s server.Server) {
|
func (f *Framework) RunServer(portName string, s server.Server) {
|
||||||
f.servers = append(f.servers, s)
|
f.servers = append(f.servers, s)
|
||||||
if s.BindPort() > 0 && portName != "" {
|
if s.BindPort() > 0 && portName != "" {
|
||||||
|
|
|
||||||
|
|
@ -75,11 +75,3 @@ func (m *MockServers) GetTemplateParams() map[string]any {
|
||||||
ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort()
|
ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort()
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServers) GetParam(key string) any {
|
|
||||||
params := m.GetTemplateParams()
|
|
||||||
if v, ok := params[key]; ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,3 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *RequestExpect) Do() (*request.Response, error) {
|
|
||||||
return e.req.Do()
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -31,13 +31,6 @@ func New(options ...Option) *Server {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithBindAddr(addr string) Option {
|
|
||||||
return func(s *Server) *Server {
|
|
||||||
s.bindAddr = addr
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithBindPort(port int) Option {
|
func WithBindPort(port int) Option {
|
||||||
return func(s *Server) *Server {
|
return func(s *Server) *Server {
|
||||||
s.bindPort = port
|
s.bindPort = port
|
||||||
|
|
|
||||||
|
|
@ -61,21 +61,6 @@ func WithBindPort(port int) Option {
|
||||||
return func(s *Server) { s.bindPort = port }
|
return func(s *Server) { s.bindPort = port }
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithClientCredentials(id, secret string) Option {
|
|
||||||
return func(s *Server) {
|
|
||||||
s.clientID = id
|
|
||||||
s.clientSecret = secret
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithAudience(aud string) Option {
|
|
||||||
return func(s *Server) { s.audience = aud }
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithSubject(sub string) Option {
|
|
||||||
return func(s *Server) { s.subject = sub }
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithExpiresIn(seconds int) Option {
|
func WithExpiresIn(seconds int) Option {
|
||||||
return func(s *Server) { s.expiresIn = seconds }
|
return func(s *Server) { s.expiresIn = seconds }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,13 +40,8 @@ type Process struct {
|
||||||
closeOne sync.Once
|
closeOne sync.Once
|
||||||
waitErr error
|
waitErr error
|
||||||
|
|
||||||
started bool
|
started bool
|
||||||
beforeStopHandler func()
|
stopped bool
|
||||||
stopped bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(path string, params []string) *Process {
|
|
||||||
return NewWithEnvs(path, params, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWithEnvs(path string, params []string, envs []string) *Process {
|
func NewWithEnvs(path string, params []string, envs []string) *Process {
|
||||||
|
|
@ -100,9 +95,6 @@ func (p *Process) Stop() error {
|
||||||
defer func() {
|
defer func() {
|
||||||
p.stopped = true
|
p.stopped = true
|
||||||
}()
|
}()
|
||||||
if p.beforeStopHandler != nil {
|
|
||||||
p.beforeStopHandler()
|
|
||||||
}
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
<-p.done
|
<-p.done
|
||||||
return p.waitErr
|
return p.waitErr
|
||||||
|
|
@ -125,10 +117,6 @@ func (p *Process) CountOutput(pattern string) int {
|
||||||
return strings.Count(p.Output(), pattern)
|
return strings.Count(p.Output(), pattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) SetBeforeStopHandler(fn func()) {
|
|
||||||
p.beforeStopHandler = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForOutput polls the combined process output until the pattern is found
|
// WaitForOutput polls the combined process output until the pattern is found
|
||||||
// count time(s) or the timeout is reached. It also returns early if the process exits.
|
// count time(s) or the timeout is reached. It also returns early if the process exits.
|
||||||
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {
|
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {
|
||||||
|
|
|
||||||
|
|
@ -22,11 +22,10 @@ type Request struct {
|
||||||
protocol string
|
protocol string
|
||||||
|
|
||||||
// for all protocol
|
// for all protocol
|
||||||
addr string
|
addr string
|
||||||
port int
|
port int
|
||||||
body []byte
|
body []byte
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
resolver *net.Resolver
|
|
||||||
|
|
||||||
// for http or https
|
// for http or https
|
||||||
method string
|
method string
|
||||||
|
|
@ -134,11 +133,6 @@ func (r *Request) Body(content []byte) *Request {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request) Resolver(resolver *net.Resolver) *Request {
|
|
||||||
r.resolver = resolver
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Request) Do() (*Response, error) {
|
func (r *Request) Do() (*Response, error) {
|
||||||
var (
|
var (
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
@ -169,7 +163,7 @@ func (r *Request) Do() (*Response, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dialer := &net.Dialer{Resolver: r.resolver}
|
dialer := &net.Dialer{}
|
||||||
switch r.protocol {
|
switch r.protocol {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
conn, err = dialer.Dial("tcp", addr)
|
conn, err = dialer.Dial("tcp", addr)
|
||||||
|
|
@ -225,7 +219,6 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma
|
||||||
Timeout: time.Second,
|
Timeout: time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 30 * time.Second,
|
||||||
DualStack: true,
|
DualStack: true,
|
||||||
Resolver: r.resolver,
|
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
MaxIdleConns: 100,
|
MaxIdleConns: 100,
|
||||||
IdleConnTimeout: 90 * time.Second,
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue