@@ -34,24 +42,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai]
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
-
-
-
-
-
- Requestly - Free & Open-Source alternative to Postman
-
- All-in-one platform to Test, Mock and Intercept APIs.
-
-
-
-
-
-
-
- The complete IDE crafted for professional Go developers
-
-
## 为什么使用 frp ?
diff --git a/Release.md b/Release.md
index 42db658a..e724fd34 100644
--- a/Release.md
+++ b/Release.md
@@ -10,9 +10,12 @@ This release introduces wire protocol v2 as a transition path for future frpc/fr
**The default value of `transport.wireProtocol` remains `v1` in this release.** Users can keep the default for now. To test v2 early, upgrade both frpc and frps to versions that support it, then set `transport.wireProtocol = "v2"` in frpc. A v2-enabled frpc cannot connect to an older frps.
+When `transport.wireProtocol = "v2"` is enabled, the control channel uses negotiated AEAD encryption after the login handshake. Both frpc and frps must be upgraded to this release to use v2.
+
v1 will be deprecated when v2 becomes the default in a future release. It will continue to be supported until v0.78.0 is released, and may be removed in v0.78.0 or later.
## Features
* Added `transport.wireProtocol` for frpc to select the internal message protocol used between frpc and frps. Supported values are `v1` and `v2`.
* Added client protocol visibility in the frps dashboard and `/api/clients` API. Online clients now report their negotiated protocol as `v1` or `v2`.
+* Wire protocol v2 now negotiates AEAD control-channel encryption. Supported algorithms are `xchacha20-poly1305` and `aes-256-gcm`; frpc advertises its preferred order based on local AES-GCM hardware support, and frps selects the first supported algorithm from that list.
diff --git a/client/control_session.go b/client/control_session.go
index 4438acfe..d533ba2d 100644
--- a/client/control_session.go
+++ b/client/control_session.go
@@ -74,17 +74,18 @@ func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, erro
return nil, err
}
- loginRespMsg, err := d.exchangeLogin(conn, loginMsg)
+ loginResult, err := d.exchangeLogin(conn, loginMsg)
if err != nil {
return nil, err
}
+ loginRespMsg := loginResult.resp
if loginRespMsg.Error != "" {
return nil, errors.New(loginRespMsg.Error)
}
var controlRW io.ReadWriter = conn
if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
- controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
+ controlRW, err = d.newControlReadWriter(conn, loginResult.crypto)
if err != nil {
return nil, fmt.Errorf("create control crypto read writer: %w", err)
}
@@ -125,9 +126,16 @@ func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login,
return loginMsg, nil
}
-func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) {
+type loginExchangeResult struct {
+ resp *msg.LoginResp
+ crypto *wire.CryptoContext
+}
+
+func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*loginExchangeResult, error) {
rw := msg.NewV1ReadWriter(conn)
var wireConn *wire.Conn
+ var clientHello wire.ClientHello
+ var clientHelloPayload []byte
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
if err := wire.WriteMagic(conn); err != nil {
@@ -136,14 +144,23 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login)
wireConn = wire.NewConn(conn)
rw = msg.NewV2ReadWriterWithConn(wireConn)
- hello := wire.DefaultClientHello(wire.BootstrapInfo{
+ var err error
+ clientHello, err = wire.NewClientHello(wire.BootstrapInfo{
Transport: d.common.Transport.Protocol,
TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
})
- if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil {
+ if err != nil {
return nil, err
}
+ clientHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeClientHello, clientHello)
+ if err != nil {
+ return nil, err
+ }
+ if err := wireConn.WriteFrame(clientHelloFrame); err != nil {
+ return nil, err
+ }
+ clientHelloPayload = clientHelloFrame.Payload
}
if err := rw.WriteMsg(loginMsg); err != nil {
return nil, err
@@ -154,19 +171,50 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login)
_ = conn.SetReadDeadline(time.Time{})
}()
+ var cryptoContext *wire.CryptoContext
if wireConn != nil {
+ serverHelloFrame, err := wireConn.ReadFrame()
+ if err != nil {
+ return nil, err
+ }
+ if serverHelloFrame.Type != wire.FrameTypeServerHello {
+ return nil, fmt.Errorf("unexpected frame type %d, want %d", serverHelloFrame.Type, wire.FrameTypeServerHello)
+ }
var serverHello wire.ServerHello
- if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil {
+ if err := wireConn.UnmarshalFrame(serverHelloFrame, &serverHello); err != nil {
return nil, err
}
if serverHello.Error != "" {
return nil, errors.New(serverHello.Error)
}
+ cryptoContext, err = wire.NewClientCryptoContext(clientHelloPayload, serverHelloFrame.Payload)
+ if err != nil {
+ return nil, err
+ }
}
var loginRespMsg msg.LoginResp
if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
return nil, err
}
- return &loginRespMsg, nil
+ return &loginExchangeResult{
+ resp: &loginRespMsg,
+ crypto: cryptoContext,
+ }, nil
+}
+
+func (d *controlSessionDialer) newControlReadWriter(conn net.Conn, cryptoContext *wire.CryptoContext) (io.ReadWriter, error) {
+ if d.common.Transport.WireProtocol == wire.ProtocolV2 {
+ if cryptoContext == nil {
+ return nil, errors.New("missing v2 crypto negotiation")
+ }
+ return netpkg.NewAEADCryptoReadWriter(
+ conn,
+ d.auth.EncryptionKey(),
+ netpkg.AEADCryptoRoleClient,
+ cryptoContext.Algorithm,
+ cryptoContext.TranscriptHash,
+ )
+ }
+ return netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
}
diff --git a/client/control_session_test.go b/client/control_session_test.go
index 9e59c9cd..a0778fba 100644
--- a/client/control_session_test.go
+++ b/client/control_session_test.go
@@ -29,6 +29,7 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/proto/wire"
+ netpkg "github.com/fatedier/frp/pkg/util/net"
)
type testConnector struct {
@@ -140,8 +141,17 @@ func TestControlSessionDialerDialV2(t *testing.T) {
}
wireConn := wire.NewConn(serverRaw)
+ clientHelloFrame, err := wireConn.ReadFrame()
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ if clientHelloFrame.Type != wire.FrameTypeClientHello {
+ serverErrCh <- fmt.Errorf("unexpected frame type %d, want %d", clientHelloFrame.Type, wire.FrameTypeClientHello)
+ return
+ }
var hello wire.ClientHello
- if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
+ if err := wireConn.UnmarshalFrame(clientHelloFrame, &hello); err != nil {
serverErrCh <- err
return
}
@@ -160,11 +170,52 @@ func TestControlSessionDialerDialV2(t *testing.T) {
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
return
}
- if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
+ serverHello, err := wire.NewServerHello(hello)
+ if err != nil {
serverErrCh <- err
return
}
- serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
+ serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ cryptoContext := wire.NewCryptoContext(
+ serverHello.Selected.Crypto.Algorithm,
+ clientHelloFrame.Payload,
+ serverHelloFrame.Payload,
+ )
+ if err := wireConn.WriteFrame(serverHelloFrame); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}); err != nil {
+ serverErrCh <- err
+ return
+ }
+
+ controlRW, err := netpkg.NewAEADCryptoReadWriter(
+ serverRaw,
+ []byte("token"),
+ netpkg.AEADCryptoRoleServer,
+ cryptoContext.Algorithm,
+ cryptoContext.TranscriptHash,
+ )
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ controlMsgRW := msg.NewReadWriter(controlRW, wire.ProtocolV2)
+ var ping msg.Ping
+ if err := controlMsgRW.ReadMsgInto(&ping); err != nil {
+ serverErrCh <- err
+ return
+ }
+ if ping.PrivilegeKey != "v2-ping" || ping.Timestamp != 12345 {
+ serverErrCh <- fmt.Errorf("unexpected ping: %+v", ping)
+ return
+ }
+ serverErrCh <- nil
}()
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
@@ -177,6 +228,7 @@ func TestControlSessionDialerDialV2(t *testing.T) {
require.NotNil(t, sessionCtx.Conn)
require.NotNil(t, sessionCtx.Connector)
require.False(t, connector.closed.Load())
+ require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{PrivilegeKey: "v2-ping", Timestamp: 12345}))
require.NoError(t, <-serverErrCh)
}
diff --git a/doc/agents/release.md b/doc/agents/release.md
index d8ec6267..d353a088 100644
--- a/doc/agents/release.md
+++ b/doc/agents/release.md
@@ -33,7 +33,51 @@ git commit -m "bump version to vX.Y.Z"
git push origin dev
```
-## 3. Merge dev → master
+## 3. Pre-release Validation
+
+Run the standard e2e suite locally:
+
+```bash
+make e2e
+```
+
+For releases that touch compatibility-sensitive areas such as login, control
+connections, work connections, visitors, transport, or wire protocol handling,
+also run the manual compatibility e2e suite:
+
+```bash
+make e2e-compatibility
+make e2e-compatibility-floor
+```
+
+`make e2e-compatibility` builds the current `frps` and `frpc`, resolves the
+recent stable release baselines from GitHub, downloads or reuses their binaries,
+and tests current binaries against those historical releases. The default number
+of recent baselines is controlled by `FRP_COMPAT_BASELINE_COUNT` in the
+`Makefile`.
+
+Downloaded release binaries are cached under:
+
+```text
+.cache/e2e-compat//_/
+```
+
+For a release validation run that must be exactly reproducible, pass an explicit
+baseline matrix instead of using the floating recent-release list:
+
+```bash
+FRP_COMPAT_BASELINE_VERSIONS="0.X.0 0.Y.0" make e2e-compatibility
+```
+
+Use `make e2e-compatibility-smoke` for a quick single-baseline check while
+iterating locally. If GitHub release metadata requests are rate-limited, set
+`GITHUB_TOKEN` or use `FRP_COMPAT_BASELINE_VERSIONS`.
+
+The compatibility floor is a support-policy decision, not a value that should
+change every release. Update `FRP_COMPAT_FLOOR_VERSION` only when the declared
+compatibility window changes.
+
+## 4. Merge dev → master
Create a PR from `dev` to `master`:
@@ -43,7 +87,7 @@ gh pr create --base master --head dev --title "bump version"
Wait for CI to pass, then merge using **merge commit** (not squash).
-## 4. Tag the Release
+## 5. Tag the Release
```bash
git checkout master
@@ -52,7 +96,7 @@ git tag -a vX.Y.Z -m "bump version"
git push origin vX.Y.Z
```
-## 5. Trigger GoReleaser
+## 6. Trigger GoReleaser
Manually trigger the `goreleaser` workflow in GitHub Actions:
diff --git a/go.mod b/go.mod
index e3a1ddd8..46bbc2d4 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.25.0
require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/coreos/go-oidc/v3 v3.14.1
- github.com/fatedier/golib v0.6.0
+ github.com/fatedier/golib v0.7.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.0
@@ -30,6 +30,7 @@ require (
golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.28.0
golang.org/x/sync v0.20.0
+ golang.org/x/sys v0.42.0
golang.org/x/time v0.10.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
gopkg.in/ini.v1 v1.67.0
@@ -68,7 +69,6 @@ require (
github.com/wlynxg/anet v0.0.5 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
golang.org/x/mod v0.33.0 // indirect
- golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
diff --git a/go.sum b/go.sum
index 4a5082c0..2fad6ea2 100644
--- a/go.sum
+++ b/go.sum
@@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/fatedier/golib v0.6.0 h1:/mgBZZbkbMhIEZoXf7nV8knpUDzas/b+2ruYKxx1lww=
-github.com/fatedier/golib v0.6.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
+github.com/fatedier/golib v0.7.0 h1:tMDF9ObcwVt59VUHroJOzHQjVFPLymZVMpGm9WAVwhY=
+github.com/fatedier/golib v0.7.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6 h1:u92UUy6FURPmNsMBUuongRWC0rBqN6gd01Dzu+D21NE=
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6/go.mod h1:c5/tk6G0dSpXGzJN7Wk1OEie8grdSJAmeawId9Zvd34=
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
diff --git a/pkg/proto/wire/crypto.go b/pkg/proto/wire/crypto.go
new file mode 100644
index 00000000..69aedc69
--- /dev/null
+++ b/pkg/proto/wire/crypto.go
@@ -0,0 +1,197 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "hash"
+ "runtime"
+
+ "golang.org/x/sys/cpu"
+)
+
+const (
+ AEADAlgorithmAES256GCM = "aes-256-gcm"
+ AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305"
+
+ CryptoRandomSize = 32
+
+ cryptoTranscriptLabel = "frp wire v2 crypto transcript"
+)
+
+var supportedAEADAlgorithms = []string{
+ AEADAlgorithmAES256GCM,
+ AEADAlgorithmXChaCha20Poly1305,
+}
+
+type CryptoContext struct {
+ Algorithm string
+ TranscriptHash []byte
+}
+
+func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) {
+ clientRandom, err := newCryptoRandom()
+ if err != nil {
+ return ClientHello{}, err
+ }
+ return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil
+}
+
+func NewServerHello(clientHello ClientHello) (ServerHello, error) {
+ if err := ValidateClientHello(clientHello); err != nil {
+ return ServerHello{}, err
+ }
+ algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms)
+ if !ok {
+ return ServerHello{}, fmt.Errorf("no supported crypto algorithm")
+ }
+ serverRandom, err := newCryptoRandom()
+ if err != nil {
+ return ServerHello{}, err
+ }
+ return ServerHello{
+ Selected: ServerSelection{
+ Message: MessageSelection{
+ Codec: MessageCodecJSON,
+ },
+ Crypto: CryptoSelection{
+ Algorithm: algorithm,
+ ServerRandom: serverRandom,
+ },
+ },
+ }, nil
+}
+
+func ValidateCryptoCapabilities(c CryptoCapabilities) error {
+ if len(c.ClientRandom) != CryptoRandomSize {
+ return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize)
+ }
+ if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok {
+ return fmt.Errorf("no supported crypto algorithm")
+ }
+ return nil
+}
+
+func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error {
+ if serverHello.Selected.Message.Codec != MessageCodecJSON {
+ return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec)
+ }
+ cryptoSelection := serverHello.Selected.Crypto
+ if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) {
+ return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm)
+ }
+ if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) {
+ return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm)
+ }
+ if len(cryptoSelection.ServerRandom) != CryptoRandomSize {
+ return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize)
+ }
+ return nil
+}
+
+func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext {
+ return &CryptoContext{
+ Algorithm: algorithm,
+ TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload),
+ }
+}
+
+func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) {
+ var clientHello ClientHello
+ if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil {
+ return nil, fmt.Errorf("decode ClientHello transcript: %w", err)
+ }
+ var serverHello ServerHello
+ if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil {
+ return nil, fmt.Errorf("decode ServerHello transcript: %w", err)
+ }
+ if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil {
+ return nil, err
+ }
+
+ return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil
+}
+
+func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte {
+ h := sha256.New()
+ _, _ = h.Write([]byte(cryptoTranscriptLabel))
+ writeCryptoTranscriptPart(h, "client hello", clientHelloPayload)
+ writeCryptoTranscriptPart(h, "server hello", serverHelloPayload)
+ return h.Sum(nil)
+}
+
+func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) {
+ var length [8]byte
+ binary.BigEndian.PutUint64(length[:], uint64(len(payload)))
+ _, _ = h.Write([]byte{0})
+ _, _ = h.Write([]byte(label))
+ _, _ = h.Write([]byte{0})
+ _, _ = h.Write(length[:])
+ _, _ = h.Write(payload)
+}
+
+func PreferredAEADAlgorithms() []string {
+ if hasFastAESGCM() {
+ return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305}
+ }
+ return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
+}
+
+func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) {
+ for _, algorithm := range clientAlgorithms {
+ if IsSupportedAEADAlgorithm(algorithm) {
+ return algorithm, true
+ }
+ }
+ return "", false
+}
+
+func IsSupportedAEADAlgorithm(algorithm string) bool {
+ return Supports(supportedAEADAlgorithms, algorithm)
+}
+
+func newCryptoRandom() ([]byte, error) {
+ b := make([]byte, CryptoRandomSize)
+ if _, err := rand.Read(b); err != nil {
+ return nil, fmt.Errorf("generate crypto random: %w", err)
+ }
+ return b, nil
+}
+
+func hasFastAESGCM() bool {
+ switch runtime.GOARCH {
+ case "amd64":
+ return cpu.X86.HasAES &&
+ cpu.X86.HasPCLMULQDQ &&
+ cpu.X86.HasSSE41 &&
+ cpu.X86.HasSSSE3
+ case "arm64":
+ return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
+ case "s390x":
+ return cpu.S390X.HasAES &&
+ cpu.S390X.HasAESCTR &&
+ cpu.S390X.HasGHASH
+ case "ppc64", "ppc64le":
+ // Go's ppc64/ppc64le port targets POWER8+, which has AES instructions;
+ // x/sys/cpu does not expose a PPC64 AES feature flag.
+ return true
+ default:
+ return false
+ }
+}
diff --git a/pkg/proto/wire/wire.go b/pkg/proto/wire/wire.go
index 5c38292d..47bf5984 100644
--- a/pkg/proto/wire/wire.go
+++ b/pkg/proto/wire/wire.go
@@ -120,15 +120,23 @@ func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
return json.Unmarshal(f.Payload, out)
}
-func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
+func NewJSONFrame(frameType uint16, in any) (*Frame, error) {
payload, err := json.Marshal(in)
+ if err != nil {
+ return nil, err
+ }
+ return &Frame{
+ Type: frameType,
+ Payload: payload,
+ }, nil
+}
+
+func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
+ f, err := NewJSONFrame(frameType, in)
if err != nil {
return err
}
- return c.WriteFrame(&Frame{
- Type: frameType,
- Payload: payload,
- })
+ return c.WriteFrame(f)
}
func WriteMagic(w io.Writer) error {
@@ -170,12 +178,18 @@ type ClientHello struct {
type ClientCapabilities struct {
Message MessageCapabilities `json:"message,omitempty"`
+ Crypto CryptoCapabilities `json:"crypto,omitempty"`
}
type MessageCapabilities struct {
Codecs []string `json:"codecs,omitempty"`
}
+type CryptoCapabilities struct {
+ Algorithms []string `json:"algorithms,omitempty"`
+ ClientRandom []byte `json:"clientRandom,omitempty"`
+}
+
type ServerHello struct {
Selected ServerSelection `json:"selected,omitempty"`
Error string `json:"error,omitempty"`
@@ -183,19 +197,29 @@ type ServerHello struct {
type ServerSelection struct {
Message MessageSelection `json:"message,omitempty"`
+ Crypto CryptoSelection `json:"crypto,omitempty"`
}
type MessageSelection struct {
Codec string `json:"codec,omitempty"`
}
-func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
+type CryptoSelection struct {
+ Algorithm string `json:"algorithm,omitempty"`
+ ServerRandom []byte `json:"serverRandom,omitempty"`
+}
+
+func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello {
return ClientHello{
Bootstrap: bootstrap,
Capabilities: ClientCapabilities{
Message: MessageCapabilities{
Codecs: []string{MessageCodecJSON},
},
+ Crypto: CryptoCapabilities{
+ Algorithms: PreferredAEADAlgorithms(),
+ ClientRandom: clientRandom,
+ },
},
}
}
@@ -218,5 +242,5 @@ func ValidateClientHello(h ClientHello) error {
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
return fmt.Errorf("unsupported message codec")
}
- return nil
+ return ValidateCryptoCapabilities(h.Capabilities.Crypto)
}
diff --git a/pkg/proto/wire/wire_test.go b/pkg/proto/wire/wire_test.go
index 19adb0be..b564f712 100644
--- a/pkg/proto/wire/wire_test.go
+++ b/pkg/proto/wire/wire_test.go
@@ -28,7 +28,7 @@ func TestFrameRoundTrip(t *testing.T) {
var buf bytes.Buffer
conn := NewConn(&buf)
- in := DefaultClientHello(BootstrapInfo{
+ in := mustClientHello(t, BootstrapInfo{
Transport: "tcp",
TLS: true,
TCPMux: true,
@@ -112,9 +112,122 @@ func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
}
func TestValidateClientHello(t *testing.T) {
- require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
+ hello := mustClientHello(t, BootstrapInfo{})
+ require.NoError(t, ValidateClientHello(hello))
+ require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize)
+ require.ElementsMatch(t, []string{
+ AEADAlgorithmAES256GCM,
+ AEADAlgorithmXChaCha20Poly1305,
+ }, hello.Capabilities.Crypto.Algorithms)
- hello := DefaultClientHello(BootstrapInfo{})
hello.Capabilities.Message.Codecs = []string{"unknown"}
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
}
+
+func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{})
+ hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1]
+ require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length")
+
+ hello = mustClientHello(t, BootstrapInfo{})
+ hello.Capabilities.Crypto.Algorithms = []string{"unknown"}
+ require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm")
+}
+
+func TestPreferredAEADAlgorithms(t *testing.T) {
+ require.ElementsMatch(t, []string{
+ AEADAlgorithmAES256GCM,
+ AEADAlgorithmXChaCha20Poly1305,
+ }, PreferredAEADAlgorithms())
+}
+
+func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{})
+ hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
+
+ serverHello, err := NewServerHello(hello)
+ require.NoError(t, err)
+ require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec)
+ require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm)
+ require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize)
+}
+
+func TestNewClientCryptoContextValidatesServerHello(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{})
+ serverHello, err := NewServerHello(hello)
+ require.NoError(t, err)
+ clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
+
+ ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
+ require.NoError(t, err)
+ require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm)
+ require.Len(t, ctx.TranscriptHash, 32)
+
+ tampered := serverHello
+ tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...)
+ tampered.Selected.Crypto.ServerRandom[0] ^= 0xff
+ _, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered)
+ tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload)
+ require.NoError(t, err)
+ require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
+}
+
+func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{
+ Transport: "tcp",
+ TLS: true,
+ TCPMux: true,
+ })
+ serverHello, err := NewServerHello(hello)
+ require.NoError(t, err)
+ clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
+
+ ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload)
+
+ tamperedHello := hello
+ tamperedHello.Bootstrap.TLS = false
+ tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello)
+ tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload)
+ require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
+}
+
+func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{})
+ serverHello, err := NewServerHello(hello)
+ require.NoError(t, err)
+
+ serverHello.Selected.Crypto.Algorithm = "unknown"
+ clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
+ _, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
+ require.ErrorContains(t, err, "unknown selected crypto algorithm")
+}
+
+func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) {
+ hello := mustClientHello(t, BootstrapInfo{})
+ hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM}
+ serverHello, err := NewServerHello(hello)
+ require.NoError(t, err)
+
+ serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305
+ clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
+ _, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
+ require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client")
+}
+
+func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello {
+ t.Helper()
+
+ hello, err := NewClientHello(bootstrap)
+ require.NoError(t, err)
+ return hello
+}
+
+func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) {
+ t.Helper()
+
+ clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello)
+ require.NoError(t, err)
+ serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello)
+ require.NoError(t, err)
+ return clientHelloFrame.Payload, serverHelloFrame.Payload
+}
diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go
index aa83b409..5dd605a5 100644
--- a/pkg/util/net/conn.go
+++ b/pkg/util/net/conn.go
@@ -16,14 +16,16 @@ package net
import (
"context"
+ "crypto/sha256"
"errors"
"io"
"net"
"sync/atomic"
"time"
- "github.com/fatedier/golib/crypto"
+ libcrypto "github.com/fatedier/golib/crypto"
quic "github.com/quic-go/quic-go"
+ "golang.org/x/crypto/hkdf"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
}
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
- encReader := crypto.NewReader(rw, key)
- encWriter, err := crypto.NewWriter(rw, key)
+ encReader := libcrypto.NewReader(rw, key)
+ encWriter, err := libcrypto.NewWriter(rw, key)
if err != nil {
return nil, err
}
@@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
Writer: encWriter,
}, nil
}
+
+type AEADCryptoRole int
+
+const (
+ AEADCryptoRoleClient AEADCryptoRole = iota + 1
+ AEADCryptoRoleServer
+)
+
+const (
+ aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
+ aeadDirectionClientToServer = "client-to-server"
+ aeadDirectionServerToClient = "server-to-client"
+)
+
+// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
+// control channel. Frames and their order are authenticated, but end-of-stream
+// is not: a clean EOF at a frame boundary is returned as normal EOF by the
+// underlying AEAD stream. Protocols that need truncation detection for finite
+// objects must add their own authenticated final message.
+func NewAEADCryptoReadWriter(
+ rw io.ReadWriter,
+ key []byte,
+ role AEADCryptoRole,
+ algorithm string,
+ transcriptHash []byte,
+) (io.ReadWriter, error) {
+ clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
+ if err != nil {
+ return nil, err
+ }
+
+ var readKey, writeKey []byte
+ switch role {
+ case AEADCryptoRoleClient:
+ readKey = serverToClientKey
+ writeKey = clientToServerKey
+ case AEADCryptoRoleServer:
+ readKey = clientToServerKey
+ writeKey = serverToClientKey
+ default:
+ return nil, errors.New("invalid aead crypto role")
+ }
+
+ encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
+ Algorithm: libcrypto.AEADAlgorithm(algorithm),
+ Key: readKey,
+ })
+ if err != nil {
+ return nil, err
+ }
+ encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
+ Algorithm: libcrypto.AEADAlgorithm(algorithm),
+ Key: writeKey,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return struct {
+ io.Reader
+ io.Writer
+ }{
+ Reader: encReader,
+ Writer: encWriter,
+ }, nil
+}
+
+func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
+ clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
+ if err != nil {
+ return nil, nil, err
+ }
+ serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
+ if err != nil {
+ return nil, nil, err
+ }
+ return clientToServerKey, serverToClientKey, nil
+}
+
+func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
+ info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
+ reader := hkdf.New(sha256.New, key, transcriptHash, info)
+ out := make([]byte, libcrypto.AEADKeySize)
+ if _, err := io.ReadFull(reader, out); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
diff --git a/pkg/util/net/conn_test.go b/pkg/util/net/conn_test.go
new file mode 100644
index 00000000..42d3c06b
--- /dev/null
+++ b/pkg/util/net/conn_test.go
@@ -0,0 +1,118 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package net
+
+import (
+ "bytes"
+ "io"
+ stdnet "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/fatedier/frp/pkg/proto/wire"
+)
+
+func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
+ clientConn, serverConn := stdnet.Pipe()
+ defer clientConn.Close()
+ defer serverConn.Close()
+
+ key := []byte("token")
+ transcriptHash := bytes.Repeat([]byte{0x11}, 32)
+ clientRW, err := NewAEADCryptoReadWriter(
+ clientConn,
+ key,
+ AEADCryptoRoleClient,
+ wire.AEADAlgorithmXChaCha20Poly1305,
+ transcriptHash,
+ )
+ require.NoError(t, err)
+ serverRW, err := NewAEADCryptoReadWriter(
+ serverConn,
+ key,
+ AEADCryptoRoleServer,
+ wire.AEADAlgorithmXChaCha20Poly1305,
+ transcriptHash,
+ )
+ require.NoError(t, err)
+
+ clientErrCh := make(chan error, 1)
+ go func() {
+ if _, err := clientRW.Write([]byte("ping")); err != nil {
+ clientErrCh <- err
+ return
+ }
+ buf := make([]byte, len("pong"))
+ _, err := io.ReadFull(clientRW, buf)
+ clientErrCh <- err
+ }()
+
+ buf := make([]byte, len("ping"))
+ _, err = io.ReadFull(serverRW, buf)
+ require.NoError(t, err)
+ require.Equal(t, "ping", string(buf))
+ _, err = serverRW.Write([]byte("pong"))
+ require.NoError(t, err)
+ require.NoError(t, <-clientErrCh)
+}
+
+func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
+ clientConn, serverConn := stdnet.Pipe()
+ defer clientConn.Close()
+ defer serverConn.Close()
+ require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
+ require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
+
+ key := []byte("token")
+ clientRW, err := NewAEADCryptoReadWriter(
+ clientConn,
+ key,
+ AEADCryptoRoleClient,
+ wire.AEADAlgorithmAES256GCM,
+ bytes.Repeat([]byte{0x22}, 32),
+ )
+ require.NoError(t, err)
+ serverRW, err := NewAEADCryptoReadWriter(
+ serverConn,
+ key,
+ AEADCryptoRoleServer,
+ wire.AEADAlgorithmAES256GCM,
+ bytes.Repeat([]byte{0x33}, 32),
+ )
+ require.NoError(t, err)
+
+ writeErrCh := make(chan error, 1)
+ go func() {
+ _, err := clientRW.Write([]byte("ping"))
+ writeErrCh <- err
+ }()
+
+ buf := make([]byte, len("ping"))
+ _, err = io.ReadFull(serverRW, buf)
+ require.Error(t, err)
+ require.NoError(t, <-writeErrCh)
+}
+
+func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
+ clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
+ []byte("token"),
+ wire.AEADAlgorithmXChaCha20Poly1305,
+ bytes.Repeat([]byte{0x44}, 32),
+ )
+ require.NoError(t, err)
+ require.NotEqual(t, clientToServerKey, serverToClientKey)
+}
diff --git a/server/service.go b/server/service.go
index 5077c21f..7ec303e4 100644
--- a/server/service.go
+++ b/server/service.go
@@ -60,6 +60,7 @@ import (
const (
connReadTimeout time.Duration = 10 * time.Second
+ connWriteTimeout time.Duration = 5 * time.Second
vhostReadWriteTimeout time.Duration = 30 * time.Second
)
@@ -456,7 +457,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
controlConn := acceptedConn.conn
if !internal {
var controlRW io.ReadWriter
- controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
+ controlRW, err = acceptedConn.newControlReadWriter(conn, svr.auth.EncryptionKey())
if err == nil {
controlConn = acceptedConn.messageConnFor(controlRW)
}
@@ -468,17 +469,23 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
if err != nil {
xl.Warnf("register control error: %v", err)
- _ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
- Version: version.Full(),
- Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
- })
+ if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
+ return acceptedConn.conn.WriteMsg(&msg.LoginResp{
+ Version: version.Full(),
+ Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
+ })
+ }); writeErr != nil {
+ xl.Warnf("write login error response error: %v", writeErr)
+ }
conn.Close()
return
}
- if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
- Version: version.Full(),
- RunID: ctl.runID,
- Error: "",
+ if err = writeWithDeadline(conn, connWriteTimeout, func() error {
+ return acceptedConn.conn.WriteMsg(&msg.LoginResp{
+ Version: version.Full(),
+ RunID: ctl.runID,
+ Error: "",
+ })
}); err != nil {
xl.Warnf("write login response error: %v", err)
svr.ctlManager.Del(m.RunID, ctl)
@@ -521,9 +528,10 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
}
type acceptedConnection struct {
- conn *msg.Conn
- wireProtocol string
- firstMsg msg.Message
+ conn *msg.Conn
+ wireProtocol string
+ cryptoContext *wire.CryptoContext
+ firstMsg msg.Message
}
func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
@@ -544,7 +552,7 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
wireConn := wire.NewConn(conn)
rw := msg.NewV2ReadWriterWithConn(wireConn)
acceptedConn.conn = msg.NewConn(conn, rw)
- acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
+ acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(conn, wireConn)
} else {
rw := msg.NewV1ReadWriter(conn)
acceptedConn.conn = msg.NewConn(conn, rw)
@@ -557,17 +565,41 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
return acceptedConn, nil
}
+func writeWithDeadline(conn net.Conn, timeout time.Duration, writeFn func() error) error {
+ _ = conn.SetWriteDeadline(time.Now().Add(timeout))
+ defer func() {
+ _ = conn.SetWriteDeadline(time.Time{})
+ }()
+ return writeFn()
+}
+
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
}
-func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
+func (ac *acceptedConnection) newControlReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
+ if ac.wireProtocol == wire.ProtocolV2 {
+ if ac.cryptoContext == nil {
+ return nil, fmt.Errorf("missing v2 crypto negotiation")
+ }
+ return netpkg.NewAEADCryptoReadWriter(
+ rw,
+ key,
+ netpkg.AEADCryptoRoleServer,
+ ac.cryptoContext.Algorithm,
+ ac.cryptoContext.TranscriptHash,
+ )
+ }
+ return netpkg.NewCryptoReadWriter(rw, key)
+}
+
+func (ac *acceptedConnection) readFirstV2Msg(conn net.Conn, wireConn *wire.Conn) (msg.Message, error) {
frame, err := wireConn.ReadFrame()
if err != nil {
return nil, fmt.Errorf("read v2 frame: %w", err)
}
if frame.Type == wire.FrameTypeClientHello {
- if err := ac.handleClientHello(wireConn, frame); err != nil {
+ if err := ac.handleClientHello(conn, wireConn, frame); err != nil {
return nil, err
}
frame, err = wireConn.ReadFrame()
@@ -583,21 +615,38 @@ func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message,
return m, nil
}
-func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
+func (ac *acceptedConnection) handleClientHello(conn net.Conn, wireConn *wire.Conn, frame *wire.Frame) error {
var hello wire.ClientHello
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
return fmt.Errorf("decode ClientHello: %w", err)
}
- serverHello := wire.DefaultServerHello()
- if err := wire.ValidateClientHello(hello); err != nil {
+ serverHello, err := wire.NewServerHello(hello)
+ if err != nil {
+ serverHello = wire.DefaultServerHello()
serverHello.Error = err.Error()
- _ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
+ if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
+ return wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
+ }); writeErr != nil {
+ return fmt.Errorf("%w; write ServerHello error: %v", err, writeErr)
+ }
return err
}
- if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
+ serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
+ if err != nil {
+ return fmt.Errorf("encode ServerHello: %w", err)
+ }
+ cryptoContext := wire.NewCryptoContext(
+ serverHello.Selected.Crypto.Algorithm,
+ frame.Payload,
+ serverHelloFrame.Payload,
+ )
+ if err := writeWithDeadline(conn, connWriteTimeout, func() error {
+ return wireConn.WriteFrame(serverHelloFrame)
+ }); err != nil {
return fmt.Errorf("write ServerHello: %w", err)
}
+ ac.cryptoContext = cryptoContext
return nil
}
diff --git a/server/service_test.go b/server/service_test.go
new file mode 100644
index 00000000..cf5eec5e
--- /dev/null
+++ b/server/service_test.go
@@ -0,0 +1,63 @@
+// Copyright 2026 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package server
+
+import (
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestWriteWithDeadlineTimesOutAndClearsDeadline(t *testing.T) {
+ serverConn, clientConn := net.Pipe()
+ defer serverConn.Close()
+ defer clientConn.Close()
+
+ err := writeWithDeadline(serverConn, 50*time.Millisecond, func() error {
+ _, writeErr := serverConn.Write([]byte("x"))
+ return writeErr
+ })
+ require.Error(t, err)
+
+ var netErr net.Error
+ require.True(t, errors.As(err, &netErr))
+ require.True(t, netErr.Timeout())
+
+ readCh := make(chan byte, 1)
+ errCh := make(chan error, 1)
+ go func() {
+ buf := make([]byte, 1)
+ if _, readErr := clientConn.Read(buf); readErr != nil {
+ errCh <- readErr
+ return
+ }
+ readCh <- buf[0]
+ }()
+
+ _, err = serverConn.Write([]byte("y"))
+ require.NoError(t, err)
+
+ select {
+ case b := <-readCh:
+ require.Equal(t, byte('y'), b)
+ case err := <-errCh:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for write after deadline reset")
+ }
+}