frp/pkg/plugin/server/manager_test.go
2026-05-12 11:09:19 +08:00

336 lines
8.7 KiB
Go

// Copyright 2019 fatedier, fatedier@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"context"
"errors"
"strings"
"sync"
"testing"
"time"
goliblog "github.com/fatedier/golib/log"
"github.com/fatedier/frp/pkg/msg"
frplog "github.com/fatedier/frp/pkg/util/log"
)
// testPlugin is a configurable in-memory plugin used to drive Manager in
// tests. It reports support only for the ops mapped to true in ops and
// delegates every Handle call to handler.
type testPlugin struct {
	name    string
	ops     map[string]bool
	handler func(context.Context, string, any) (*Response, any, error)
}
var (
	// logCaptureMu serializes tests that replace the global frplog.Logger.
	// captureLogOutput holds it until the calling test's cleanup runs, so:
	//   - log-capturing tests must not use t.Parallel, because any other test
	//     logging through frplog.Logger at the same time would write into the
	//     captured output;
	//   - a test that captures output must not run subtests that capture
	//     again, because sync.Mutex is not reentrant and the parent only
	//     releases it after its subtests finish (deadlock).
	logCaptureMu sync.Mutex
)
// logCapture is a log output sink for tests. WriteLog appends the formatted
// entry to the embedded buffer and records its level, in call order, so tests
// can assert on both the text and the levels. It has no locking and is not
// safe for concurrent use.
type logCapture struct {
	bytes.Buffer
	levels []goliblog.Level
}
// Name returns the plugin name the test configured.
func (p testPlugin) Name() string { return p.name }
// IsSupport reports whether op is enabled in the plugin's ops set. Ops that
// are absent from the map (or a nil map) are unsupported.
func (p testPlugin) IsSupport(op string) bool {
	enabled, found := p.ops[op]
	return found && enabled
}
// Handle forwards the call unchanged to the configured handler and returns
// whatever it returns.
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
	h := p.handler
	return h(ctx, op, content)
}
// WriteLog records level, then appends the formatted entry p to the buffer
// and returns the buffer write's result. The timestamp is ignored.
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
	w.levels = append(w.levels, level)
	n, err := w.Buffer.Write(p)
	return n, err
}
// captureLogOutput points the global frplog.Logger at a new logger that is
// configured with goliblog.TraceLevel, has caller info disabled and writes
// into a fresh logCapture. It returns that capture for assertions.
//
// The function holds logCaptureMu for the rest of the test. When t's cleanups
// run, the previous logger is restored first and the mutex is released
// afterwards. Tests using it must not call t.Parallel.
func captureLogOutput(t *testing.T) *logCapture {
	t.Helper()
	logCaptureMu.Lock()
	// Cleanups run last-in first-out, so this unlock fires only after the
	// logger restore registered below.
	t.Cleanup(logCaptureMu.Unlock)

	capture := new(logCapture)
	previous := frplog.Logger
	frplog.Logger = goliblog.New(
		goliblog.WithOutput(capture),
		goliblog.WithLevel(goliblog.TraceLevel),
		goliblog.WithCaller(false),
	)
	t.Cleanup(func() { frplog.Logger = previous })
	return capture
}
// mutablePluginOps is the table of mutable-content manager ops exercised by
// the table-driven tests. Each entry pairs a subtest name with an op that
// callMutableWithUser knows how to dispatch.
var mutablePluginOps = []struct {
	name, op string
}{
	{"login", OpLogin},
	{"new proxy", OpNewProxy},
	{"ping", OpPing},
	{"new work conn", OpNewWorkConn},
	{"new user conn", OpNewUserConn},
}
// callMutableWithUser invokes the Manager method matching op with request
// content whose user is set to user. It returns the user carried by the
// content the manager hands back ("" when that content is nil) together with
// the manager's error. It panics for any op other than the five mutable ops.
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
	info := UserInfo{User: user}
	switch op {
	case OpLogin:
		out, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
		if out != nil {
			return out.User, err
		}
		return "", err
	case OpNewProxy:
		out, err := m.NewProxy(&NewProxyContent{User: info})
		if out != nil {
			return out.User.User, err
		}
		return "", err
	case OpPing:
		out, err := m.Ping(&PingContent{User: info})
		if out != nil {
			return out.User.User, err
		}
		return "", err
	case OpNewWorkConn:
		out, err := m.NewWorkConn(&NewWorkConnContent{User: info})
		if out != nil {
			return out.User.User, err
		}
		return "", err
	case OpNewUserConn:
		out, err := m.NewUserConn(&NewUserConnContent{User: info})
		if out != nil {
			return out.User.User, err
		}
		return "", err
	}
	panic("unsupported mutable op: " + op)
}
// mutableUser extracts the user name from the plugin request content for op.
//
// The content is expected to be the value type matching op (for example,
// LoginContent rather than *LoginContent). An unsupported op, or content of
// the wrong type, fails the test with a descriptive message. Without this
// check, a bare type-assertion panic would surface from inside the plugin
// handler.
func mutableUser(t *testing.T, op string, content any) string {
	t.Helper()
	switch op {
	case OpLogin:
		if c, ok := content.(LoginContent); ok {
			return c.User
		}
	case OpNewProxy:
		if c, ok := content.(NewProxyContent); ok {
			return c.User.User
		}
	case OpPing:
		if c, ok := content.(PingContent); ok {
			return c.User.User
		}
	case OpNewWorkConn:
		if c, ok := content.(NewWorkConnContent); ok {
			return c.User.User
		}
	case OpNewUserConn:
		if c, ok := content.(NewUserConnContent); ok {
			return c.User.User
		}
	default:
		t.Fatalf("unsupported mutable op: %s", op)
		return ""
	}
	// Reached only when op is known but the assertion above failed.
	t.Fatalf("unexpected content type %T for op %s", content, op)
	return ""
}
// mutateMutableContent copies the value-typed request content for op, sets
// the copy's user to user and returns a pointer to the copy. The content
// passed in is not modified.
//
// An unsupported op, or content of the wrong type, fails the test with a
// descriptive message instead of a type-assertion panic inside the plugin
// handler.
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
	t.Helper()
	switch op {
	case OpLogin:
		if c, ok := content.(LoginContent); ok {
			c.User = user
			return &c
		}
	case OpNewProxy:
		if c, ok := content.(NewProxyContent); ok {
			c.User.User = user
			return &c
		}
	case OpPing:
		if c, ok := content.(PingContent); ok {
			c.User.User = user
			return &c
		}
	case OpNewWorkConn:
		if c, ok := content.(NewWorkConnContent); ok {
			c.User.User = user
			return &c
		}
	case OpNewUserConn:
		if c, ok := content.(NewUserConnContent); ok {
			c.User.User = user
			return &c
		}
	default:
		t.Fatalf("unsupported mutable op: %s", op)
		return nil
	}
	// Reached only when op is known but the assertion above failed.
	t.Fatalf("unexpected content type %T for op %s", content, op)
	return nil
}
// TestManagerMutableContentAcrossOps checks, for every mutable op, that:
//   - content returned with Unchange=false is passed to the next plugin and
//     returned to the caller;
//   - content returned with Unchange=true is discarded.
//
// Each plugin also asserts that it received the expected op and a request id
// in its context, and each must run exactly once.
func TestManagerMutableContentAcrossOps(t *testing.T) {
	for _, tt := range mutablePluginOps {
		t.Run(tt.name, func(t *testing.T) {
			// Count invocations so the test fails if the manager skips a
			// plugin. Otherwise the "observe" assertions could silently never
			// run while got still equals "mutated".
			var mutateCalls, observeCalls int

			m := NewManager()
			m.Register(testPlugin{
				name: "mutate",
				ops:  map[string]bool{tt.op: true},
				handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
					mutateCalls++
					if op != tt.op {
						t.Fatalf("unexpected op: %s", op)
					}
					if GetReqidFromContext(ctx) == "" {
						t.Fatal("expected request id in context")
					}
					if got := mutableUser(t, tt.op, content); got != "initial" {
						t.Fatalf("expected initial user, got %q", got)
					}
					// Unchange=false: the manager should adopt this content.
					return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
				},
			})
			m.Register(testPlugin{
				name: "observe",
				ops:  map[string]bool{tt.op: true},
				handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
					observeCalls++
					if op != tt.op {
						t.Fatalf("unexpected op: %s", op)
					}
					if GetReqidFromContext(ctx) == "" {
						t.Fatal("expected request id in context")
					}
					if got := mutableUser(t, tt.op, content); got != "mutated" {
						t.Fatalf("expected mutated user, got %q", got)
					}
					// Unchange=true: the manager should discard this content.
					return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
				},
			})

			got, err := callMutableWithUser(m, tt.op, "initial")
			if err != nil {
				t.Fatalf("mutable op failed: %v", err)
			}
			if mutateCalls != 1 || observeCalls != 1 {
				t.Fatalf("expected each plugin to run once, got mutate=%d observe=%d", mutateCalls, observeCalls)
			}
			if got != "mutated" {
				t.Fatalf("expected mutated user, got %q", got)
			}
		})
	}
}
// TestManagerMutableContentRejectStopsChain checks that when a plugin rejects
// a Ping:
//   - its reject reason becomes the returned error;
//   - no content is returned;
//   - plugins registered after it are never invoked.
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
	m := NewManager()
	secondCalled := false

	m.Register(testPlugin{
		name: "reject",
		ops:  map[string]bool{OpPing: true},
		handler: func(context.Context, string, any) (*Response, any, error) {
			return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
		},
	})
	m.Register(testPlugin{
		name: "unused",
		ops:  map[string]bool{OpPing: true},
		handler: func(context.Context, string, any) (*Response, any, error) {
			secondCalled = true
			return &Response{Unchange: true}, nil, nil
		},
	})

	got, err := m.Ping(&PingContent{})
	switch {
	case err == nil:
		t.Fatal("expected reject error")
	case got != nil:
		t.Fatalf("expected no returned content, got %#v", got)
	case err.Error() != "blocked":
		t.Fatalf("unexpected error: %v", err)
	}
	if secondCalled {
		t.Fatal("expected plugin chain to stop after reject")
	}
}
// TestManagerMutableContentPluginErrorLogLevel checks that a plugin error:
//   - aborts the op with the generic "send <op> request to plugin error"
//     message;
//   - returns no content;
//   - is logged exactly once at the op's level (info for NewUserConn,
//     warning by default).
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
	tests := []struct {
		name  string
		op    string
		level goliblog.Level
	}{
		{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
		{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			logOutput := captureLogOutput(t)
			m := NewManager()
			m.Register(testPlugin{
				name: "error",
				ops:  map[string]bool{tt.op: true},
				handler: func(context.Context, string, any) (*Response, any, error) {
					return nil, nil, errors.New("boom")
				},
			})
			got, err := callMutableWithUser(m, tt.op, "initial")
			if err == nil {
				t.Fatal("expected plugin error")
			}
			if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
				t.Fatalf("unexpected error: got %q, want %q", err, want)
			}
			// callMutableWithUser reports "" when the manager returns nil
			// content; "initial" here would mean the input leaked back.
			if got != "" {
				t.Fatalf("expected no returned content, got user %q", got)
			}
			if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
				t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
			}
		})
	}
}
// TestManagerCloseProxyAggregatesErrors checks that CloseProxy keeps calling
// every plugin after one fails, and that it returns one error carrying a
// common prefix plus each failure tagged with its plugin name. Each failure
// must be logged once at warning level. Every plugin also asserts it received
// the CloseProxy op and a request id in its context.
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
	logOutput := captureLogOutput(t)
	m := NewManager()
	for _, pluginName := range []string{"first", "second"} {
		m.Register(testPlugin{
			name: pluginName,
			ops:  map[string]bool{OpCloseProxy: true},
			handler: func(ctx context.Context, op string, _ any) (*Response, any, error) {
				if GetReqidFromContext(ctx) == "" {
					t.Fatal("expected request id in context")
				}
				if op != OpCloseProxy {
					t.Fatalf("unexpected op: %s", op)
				}
				return nil, nil, errors.New(pluginName + " error")
			},
		})
	}

	err := m.CloseProxy(&CloseProxyContent{})
	if err == nil {
		t.Fatal("expected close proxy error")
	}
	errText := err.Error()
	if !strings.HasPrefix(errText, "send CloseProxy request to plugin errors: ") {
		t.Fatalf("unexpected close proxy error prefix: %v", err)
	}
	for _, fragment := range []string{"[first]: first error", "[second]: second error"} {
		if !strings.Contains(errText, fragment) {
			t.Fatalf("missing aggregated errors: %v", err)
		}
	}

	levels := logOutput.levels
	if len(levels) != 2 {
		t.Fatalf("expected two warning logs, got %v", levels)
	}
	for i := range levels {
		if levels[i] != goliblog.WarnLevel {
			t.Fatalf("expected warning log level, got %v", levels)
		}
	}
}