mirror of
https://github.com/fatedier/frp.git
synced 2026-05-15 08:05:49 -06:00
336 lines
8.7 KiB
Go
336 lines
8.7 KiB
Go
// Copyright 2019 fatedier, fatedier@gmail.com
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
goliblog "github.com/fatedier/golib/log"
|
|
|
|
"github.com/fatedier/frp/pkg/msg"
|
|
frplog "github.com/fatedier/frp/pkg/util/log"
|
|
)
|
|
|
|
type testPlugin struct {
|
|
name string
|
|
ops map[string]bool
|
|
handler func(context.Context, string, any) (*Response, any, error)
|
|
}
|
|
|
|
// Log-capturing subtests serialize global logger swaps; do not use t.Parallel.
|
|
var logCaptureMu sync.Mutex
|
|
|
|
type logCapture struct {
|
|
bytes.Buffer
|
|
levels []goliblog.Level
|
|
}
|
|
|
|
func (p testPlugin) Name() string {
|
|
return p.name
|
|
}
|
|
|
|
func (p testPlugin) IsSupport(op string) bool {
|
|
return p.ops[op]
|
|
}
|
|
|
|
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
|
return p.handler(ctx, op, content)
|
|
}
|
|
|
|
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
|
|
w.levels = append(w.levels, level)
|
|
return w.Write(p)
|
|
}
|
|
|
|
func captureLogOutput(t *testing.T) *logCapture {
|
|
t.Helper()
|
|
|
|
logCaptureMu.Lock()
|
|
logOutput := &logCapture{}
|
|
oldLogger := frplog.Logger
|
|
frplog.Logger = goliblog.New(
|
|
goliblog.WithOutput(logOutput),
|
|
goliblog.WithLevel(goliblog.TraceLevel),
|
|
goliblog.WithCaller(false),
|
|
)
|
|
t.Cleanup(func() {
|
|
frplog.Logger = oldLogger
|
|
logCaptureMu.Unlock()
|
|
})
|
|
return logOutput
|
|
}
|
|
|
|
var mutablePluginOps = []struct {
|
|
name string
|
|
op string
|
|
}{
|
|
{name: "login", op: OpLogin},
|
|
{name: "new proxy", op: OpNewProxy},
|
|
{name: "ping", op: OpPing},
|
|
{name: "new work conn", op: OpNewWorkConn},
|
|
{name: "new user conn", op: OpNewUserConn},
|
|
}
|
|
|
|
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
|
|
switch op {
|
|
case OpLogin:
|
|
got, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
|
|
if got == nil {
|
|
return "", err
|
|
}
|
|
return got.User, err
|
|
case OpNewProxy:
|
|
got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}})
|
|
if got == nil {
|
|
return "", err
|
|
}
|
|
return got.User.User, err
|
|
case OpPing:
|
|
got, err := m.Ping(&PingContent{User: UserInfo{User: user}})
|
|
if got == nil {
|
|
return "", err
|
|
}
|
|
return got.User.User, err
|
|
case OpNewWorkConn:
|
|
got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}})
|
|
if got == nil {
|
|
return "", err
|
|
}
|
|
return got.User.User, err
|
|
case OpNewUserConn:
|
|
got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}})
|
|
if got == nil {
|
|
return "", err
|
|
}
|
|
return got.User.User, err
|
|
default:
|
|
panic("unsupported mutable op: " + op)
|
|
}
|
|
}
|
|
|
|
func mutableUser(t *testing.T, op string, content any) string {
|
|
t.Helper()
|
|
|
|
switch op {
|
|
case OpLogin:
|
|
return content.(LoginContent).User
|
|
case OpNewProxy:
|
|
return content.(NewProxyContent).User.User
|
|
case OpPing:
|
|
return content.(PingContent).User.User
|
|
case OpNewWorkConn:
|
|
return content.(NewWorkConnContent).User.User
|
|
case OpNewUserConn:
|
|
return content.(NewUserConnContent).User.User
|
|
default:
|
|
t.Fatalf("unsupported mutable op: %s", op)
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
|
|
t.Helper()
|
|
|
|
switch op {
|
|
case OpLogin:
|
|
got := content.(LoginContent)
|
|
got.User = user
|
|
return &got
|
|
case OpNewProxy:
|
|
got := content.(NewProxyContent)
|
|
got.User.User = user
|
|
return &got
|
|
case OpPing:
|
|
got := content.(PingContent)
|
|
got.User.User = user
|
|
return &got
|
|
case OpNewWorkConn:
|
|
got := content.(NewWorkConnContent)
|
|
got.User.User = user
|
|
return &got
|
|
case OpNewUserConn:
|
|
got := content.(NewUserConnContent)
|
|
got.User.User = user
|
|
return &got
|
|
default:
|
|
t.Fatalf("unsupported mutable op: %s", op)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func TestManagerMutableContentAcrossOps(t *testing.T) {
|
|
for _, tt := range mutablePluginOps {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
m := NewManager()
|
|
m.Register(testPlugin{
|
|
name: "mutate",
|
|
ops: map[string]bool{tt.op: true},
|
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
|
if op != tt.op {
|
|
t.Fatalf("unexpected op: %s", op)
|
|
}
|
|
if GetReqidFromContext(ctx) == "" {
|
|
t.Fatal("expected request id in context")
|
|
}
|
|
if got := mutableUser(t, tt.op, content); got != "initial" {
|
|
t.Fatalf("expected initial user, got %q", got)
|
|
}
|
|
return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
|
|
},
|
|
})
|
|
m.Register(testPlugin{
|
|
name: "observe",
|
|
ops: map[string]bool{tt.op: true},
|
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
|
if op != tt.op {
|
|
t.Fatalf("unexpected op: %s", op)
|
|
}
|
|
if GetReqidFromContext(ctx) == "" {
|
|
t.Fatal("expected request id in context")
|
|
}
|
|
if got := mutableUser(t, tt.op, content); got != "mutated" {
|
|
t.Fatalf("expected mutated user, got %q", got)
|
|
}
|
|
return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
|
|
},
|
|
})
|
|
|
|
got, err := callMutableWithUser(m, tt.op, "initial")
|
|
if err != nil {
|
|
t.Fatalf("mutable op failed: %v", err)
|
|
}
|
|
if got != "mutated" {
|
|
t.Fatalf("expected mutated user, got %q", got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
|
|
m := NewManager()
|
|
|
|
var called bool
|
|
m.Register(testPlugin{
|
|
name: "reject",
|
|
ops: map[string]bool{OpPing: true},
|
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
|
return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
|
|
},
|
|
})
|
|
m.Register(testPlugin{
|
|
name: "unused",
|
|
ops: map[string]bool{OpPing: true},
|
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
|
called = true
|
|
return &Response{Unchange: true}, nil, nil
|
|
},
|
|
})
|
|
|
|
got, err := m.Ping(&PingContent{})
|
|
if err == nil {
|
|
t.Fatal("expected reject error")
|
|
}
|
|
if got != nil {
|
|
t.Fatalf("expected no returned content, got %#v", got)
|
|
}
|
|
if err.Error() != "blocked" {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if called {
|
|
t.Fatal("expected plugin chain to stop after reject")
|
|
}
|
|
}
|
|
|
|
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
op string
|
|
level goliblog.Level
|
|
}{
|
|
{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
|
|
{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
logOutput := captureLogOutput(t)
|
|
m := NewManager()
|
|
m.Register(testPlugin{
|
|
name: "error",
|
|
ops: map[string]bool{tt.op: true},
|
|
handler: func(context.Context, string, any) (*Response, any, error) {
|
|
return nil, nil, errors.New("boom")
|
|
},
|
|
})
|
|
|
|
_, err := callMutableWithUser(m, tt.op, "initial")
|
|
if err == nil {
|
|
t.Fatal("expected plugin error")
|
|
}
|
|
if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
|
|
t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
|
|
logOutput := captureLogOutput(t)
|
|
m := NewManager()
|
|
|
|
for _, name := range []string{"first", "second"} {
|
|
m.Register(testPlugin{
|
|
name: name,
|
|
ops: map[string]bool{OpCloseProxy: true},
|
|
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
|
if GetReqidFromContext(ctx) == "" {
|
|
t.Fatal("expected request id in context")
|
|
}
|
|
if op != OpCloseProxy {
|
|
t.Fatalf("unexpected op: %s", op)
|
|
}
|
|
return nil, nil, errors.New(name + " error")
|
|
},
|
|
})
|
|
}
|
|
|
|
err := m.CloseProxy(&CloseProxyContent{})
|
|
if err == nil {
|
|
t.Fatal("expected close proxy error")
|
|
}
|
|
if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") {
|
|
t.Fatalf("unexpected close proxy error prefix: %v", err)
|
|
}
|
|
if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") {
|
|
t.Fatalf("missing aggregated errors: %v", err)
|
|
}
|
|
if len(logOutput.levels) != 2 {
|
|
t.Fatalf("expected two warning logs, got %v", logOutput.levels)
|
|
}
|
|
for _, level := range logOutput.levels {
|
|
if level != goliblog.WarnLevel {
|
|
t.Fatalf("expected warning log level, got %v", logOutput.levels)
|
|
}
|
|
}
|
|
}
|