diff --git a/README.md b/README.md index 94d26664..d850c854 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,14 @@ frp is an open source project with its ongoing development made possible entirel

Gold Sponsors

+

+ + +
+ The complete IDE crafted for professional Go developers +
+

+

@@ -32,24 +40,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai] an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more. - -

- - -
- Requestly - Free & Open-Source alternative to Postman -
- All-in-one platform to Test, Mock and Intercept APIs. -
-

- -

- - -
- The complete IDE crafted for professional Go developers -
-

## What is frp? diff --git a/README_zh.md b/README_zh.md index 9a03b031..a08d4401 100644 --- a/README_zh.md +++ b/README_zh.md @@ -15,6 +15,14 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者

Gold Sponsors

+

+ + +
+ The complete IDE crafted for professional Go developers +
+

+

@@ -34,24 +42,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai] an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more. - -

- - -
- Requestly - Free & Open-Source alternative to Postman -
- All-in-one platform to Test, Mock and Intercept APIs. -
-

- -

- - -
- The complete IDE crafted for professional Go developers -
-

## 为什么使用 frp ? diff --git a/Release.md b/Release.md index 42db658a..e724fd34 100644 --- a/Release.md +++ b/Release.md @@ -10,9 +10,12 @@ This release introduces wire protocol v2 as a transition path for future frpc/fr **The default value of `transport.wireProtocol` remains `v1` in this release.** Users can keep the default for now. To test v2 early, upgrade both frpc and frps to versions that support it, then set `transport.wireProtocol = "v2"` in frpc. A v2-enabled frpc cannot connect to an older frps. +When `transport.wireProtocol = "v2"` is enabled, the control channel uses negotiated AEAD encryption after the login handshake. Both frpc and frps must be upgraded to this release to use v2. + v1 will be deprecated when v2 becomes the default in a future release. It will continue to be supported until v0.78.0 is released, and may be removed in v0.78.0 or later. ## Features * Added `transport.wireProtocol` for frpc to select the internal message protocol used between frpc and frps. Supported values are `v1` and `v2`. * Added client protocol visibility in the frps dashboard and `/api/clients` API. Online clients now report their negotiated protocol as `v1` or `v2`. +* Wire protocol v2 now negotiates AEAD control-channel encryption. Supported algorithms are `xchacha20-poly1305` and `aes-256-gcm`; frpc advertises its preferred order based on local AES-GCM hardware support, and frps selects the first supported algorithm from that list. diff --git a/client/control_session.go b/client/control_session.go index 4438acfe..d533ba2d 100644 --- a/client/control_session.go +++ b/client/control_session.go @@ -74,17 +74,18 @@ func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, erro return nil, err } - loginRespMsg, err := d.exchangeLogin(conn, loginMsg) + loginResult, err := d.exchangeLogin(conn, loginMsg) if err != nil { return nil, err } + loginRespMsg := loginResult.resp if loginRespMsg.Error != "" { return nil, errors.New(loginRespMsg.Error) } var controlRW io.ReadWriter = conn if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" { - controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey()) + controlRW, err = d.newControlReadWriter(conn, loginResult.crypto) if err != nil { return nil, fmt.Errorf("create control crypto read writer: %w", err) } @@ -125,9 +126,16 @@ func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, return loginMsg, nil } -func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) { +type loginExchangeResult struct { + resp *msg.LoginResp + crypto *wire.CryptoContext +} + +func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*loginExchangeResult, error) { rw := msg.NewV1ReadWriter(conn) var wireConn *wire.Conn + var clientHello wire.ClientHello + var clientHelloPayload []byte if d.common.Transport.WireProtocol == wire.ProtocolV2 { if err := wire.WriteMagic(conn); err != nil { @@ -136,14 +144,23 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) wireConn = wire.NewConn(conn) rw = msg.NewV2ReadWriterWithConn(wireConn) - hello := wire.DefaultClientHello(wire.BootstrapInfo{ + var err error + clientHello, err = wire.NewClientHello(wire.BootstrapInfo{ Transport: d.common.Transport.Protocol, TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic", TCPMux: lo.FromPtr(d.common.Transport.TCPMux), }) - if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil { + if err != nil { return nil, err } + clientHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeClientHello, clientHello) + if err != nil { + return nil, err + } + if err := wireConn.WriteFrame(clientHelloFrame); err != nil { + return nil, err + } + clientHelloPayload = clientHelloFrame.Payload } if err := rw.WriteMsg(loginMsg); err != nil { return nil, err @@ -154,19 +171,50 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) _ = conn.SetReadDeadline(time.Time{}) }() + var cryptoContext *wire.CryptoContext if wireConn != nil { + serverHelloFrame, err := wireConn.ReadFrame() + if err != nil { + return nil, err + } + if serverHelloFrame.Type != wire.FrameTypeServerHello { + return nil, fmt.Errorf("unexpected frame type %d, want %d", serverHelloFrame.Type, wire.FrameTypeServerHello) + } var serverHello wire.ServerHello - if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil { + if err := wireConn.UnmarshalFrame(serverHelloFrame, &serverHello); err != nil { return nil, err } if serverHello.Error != "" { return nil, errors.New(serverHello.Error) } + cryptoContext, err = wire.NewClientCryptoContext(clientHelloPayload, serverHelloFrame.Payload) + if err != nil { + return nil, err + } } var loginRespMsg msg.LoginResp if err := rw.ReadMsgInto(&loginRespMsg); err != nil { return nil, err } - return &loginRespMsg, nil + return &loginExchangeResult{ + resp: &loginRespMsg, + crypto: cryptoContext, + }, nil +} + +func (d *controlSessionDialer) newControlReadWriter(conn net.Conn, cryptoContext *wire.CryptoContext) (io.ReadWriter, error) { + if d.common.Transport.WireProtocol == wire.ProtocolV2 { + if cryptoContext == nil { + return nil, errors.New("missing v2 crypto negotiation") + } + return netpkg.NewAEADCryptoReadWriter( + conn, + d.auth.EncryptionKey(), + netpkg.AEADCryptoRoleClient, + cryptoContext.Algorithm, + cryptoContext.TranscriptHash, + ) + } + return netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey()) } diff --git a/client/control_session_test.go b/client/control_session_test.go index 9e59c9cd..a0778fba 100644 --- a/client/control_session_test.go +++ b/client/control_session_test.go @@ -29,6 +29,7 @@ import ( v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/proto/wire" + netpkg "github.com/fatedier/frp/pkg/util/net" ) type testConnector struct { @@ -140,8 +141,17 @@ func TestControlSessionDialerDialV2(t *testing.T) { } wireConn := wire.NewConn(serverRaw) + clientHelloFrame, err := wireConn.ReadFrame() + if err != nil { + serverErrCh <- err + return + } + if clientHelloFrame.Type != wire.FrameTypeClientHello { + serverErrCh <- fmt.Errorf("unexpected frame type %d, want %d", clientHelloFrame.Type, wire.FrameTypeClientHello) + return + } var hello wire.ClientHello - if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil { + if err := wireConn.UnmarshalFrame(clientHelloFrame, &hello); err != nil { serverErrCh <- err return } @@ -160,11 +170,52 @@ func TestControlSessionDialerDialV2(t *testing.T) { serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User) return } - if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil { + serverHello, err := wire.NewServerHello(hello) + if err != nil { serverErrCh <- err return } - serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}) + serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello) + if err != nil { + serverErrCh <- err + return + } + cryptoContext := wire.NewCryptoContext( + serverHello.Selected.Crypto.Algorithm, + clientHelloFrame.Payload, + serverHelloFrame.Payload, + ) + if err := wireConn.WriteFrame(serverHelloFrame); err != nil { + serverErrCh <- err + return + } + if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}); err != nil { + serverErrCh <- err + return + } + + controlRW, err := netpkg.NewAEADCryptoReadWriter( + serverRaw, + []byte("token"), + netpkg.AEADCryptoRoleServer, + cryptoContext.Algorithm, + cryptoContext.TranscriptHash, + ) + if err != nil { + serverErrCh <- err + return + } + controlMsgRW := msg.NewReadWriter(controlRW, wire.ProtocolV2) + var ping msg.Ping + if err := controlMsgRW.ReadMsgInto(&ping); err != nil { + serverErrCh <- err + return + } + if ping.PrivilegeKey != "v2-ping" || ping.Timestamp != 12345 { + serverErrCh <- fmt.Errorf("unexpected ping: %+v", ping) + return + } + serverErrCh <- nil }() dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil) @@ -177,6 +228,7 @@ func TestControlSessionDialerDialV2(t *testing.T) { require.NotNil(t, sessionCtx.Conn) require.NotNil(t, sessionCtx.Connector) require.False(t, connector.closed.Load()) + require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{PrivilegeKey: "v2-ping", Timestamp: 12345})) require.NoError(t, <-serverErrCh) } diff --git a/doc/agents/release.md b/doc/agents/release.md index d8ec6267..d353a088 100644 --- a/doc/agents/release.md +++ b/doc/agents/release.md @@ -33,7 +33,51 @@ git commit -m "bump version to vX.Y.Z" git push origin dev ``` -## 3. Merge dev → master +## 3. Pre-release Validation + +Run the standard e2e suite locally: + +```bash +make e2e +``` + +For releases that touch compatibility-sensitive areas such as login, control +connections, work connections, visitors, transport, or wire protocol handling, +also run the manual compatibility e2e suite: + +```bash +make e2e-compatibility +make e2e-compatibility-floor +``` + +`make e2e-compatibility` builds the current `frps` and `frpc`, resolves the +recent stable release baselines from GitHub, downloads or reuses their binaries, +and tests current binaries against those historical releases. The default number +of recent baselines is controlled by `FRP_COMPAT_BASELINE_COUNT` in the +`Makefile`. + +Downloaded release binaries are cached under: + +```text +.cache/e2e-compat//_/ +``` + +For a release validation run that must be exactly reproducible, pass an explicit +baseline matrix instead of using the floating recent-release list: + +```bash +FRP_COMPAT_BASELINE_VERSIONS="0.X.0 0.Y.0" make e2e-compatibility +``` + +Use `make e2e-compatibility-smoke` for a quick single-baseline check while +iterating locally. If GitHub release metadata requests are rate-limited, set +`GITHUB_TOKEN` or use `FRP_COMPAT_BASELINE_VERSIONS`. + +The compatibility floor is a support-policy decision, not a value that should +change every release. Update `FRP_COMPAT_FLOOR_VERSION` only when the declared +compatibility window changes. + +## 4. Merge dev → master Create a PR from `dev` to `master`: @@ -43,7 +87,7 @@ gh pr create --base master --head dev --title "bump version" Wait for CI to pass, then merge using **merge commit** (not squash). -## 4. Tag the Release +## 5. Tag the Release ```bash git checkout master @@ -52,7 +96,7 @@ git tag -a vX.Y.Z -m "bump version" git push origin vX.Y.Z ``` -## 5. Trigger GoReleaser +## 6. Trigger GoReleaser Manually trigger the `goreleaser` workflow in GitHub Actions: diff --git a/go.mod b/go.mod index e3a1ddd8..46bbc2d4 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.25.0 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/coreos/go-oidc/v3 v3.14.1 - github.com/fatedier/golib v0.6.0 + github.com/fatedier/golib v0.7.0 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.0 @@ -30,6 +30,7 @@ require ( golang.org/x/net v0.52.0 golang.org/x/oauth2 v0.28.0 golang.org/x/sync v0.20.0 + golang.org/x/sys v0.42.0 golang.org/x/time v0.10.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 gopkg.in/ini.v1 v1.67.0 @@ -68,7 +69,6 @@ require ( github.com/wlynxg/anet v0.0.5 // indirect go.uber.org/automaxprocs v1.6.0 // indirect golang.org/x/mod v0.33.0 // indirect - golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/tools v0.42.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 4a5082c0..2fad6ea2 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatedier/golib v0.6.0 h1:/mgBZZbkbMhIEZoXf7nV8knpUDzas/b+2ruYKxx1lww= -github.com/fatedier/golib v0.6.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw= +github.com/fatedier/golib v0.7.0 h1:tMDF9ObcwVt59VUHroJOzHQjVFPLymZVMpGm9WAVwhY= +github.com/fatedier/golib v0.7.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw= github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6 h1:u92UUy6FURPmNsMBUuongRWC0rBqN6gd01Dzu+D21NE= github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6/go.mod h1:c5/tk6G0dSpXGzJN7Wk1OEie8grdSJAmeawId9Zvd34= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= diff --git a/pkg/proto/wire/crypto.go b/pkg/proto/wire/crypto.go new file mode 100644 index 00000000..69aedc69 --- /dev/null +++ b/pkg/proto/wire/crypto.go @@ -0,0 +1,197 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/json" + "fmt" + "hash" + "runtime" + + "golang.org/x/sys/cpu" +) + +const ( + AEADAlgorithmAES256GCM = "aes-256-gcm" + AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305" + + CryptoRandomSize = 32 + + cryptoTranscriptLabel = "frp wire v2 crypto transcript" +) + +var supportedAEADAlgorithms = []string{ + AEADAlgorithmAES256GCM, + AEADAlgorithmXChaCha20Poly1305, +} + +type CryptoContext struct { + Algorithm string + TranscriptHash []byte +} + +func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) { + clientRandom, err := newCryptoRandom() + if err != nil { + return ClientHello{}, err + } + return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil +} + +func NewServerHello(clientHello ClientHello) (ServerHello, error) { + if err := ValidateClientHello(clientHello); err != nil { + return ServerHello{}, err + } + algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms) + if !ok { + return ServerHello{}, fmt.Errorf("no supported crypto algorithm") + } + serverRandom, err := newCryptoRandom() + if err != nil { + return ServerHello{}, err + } + return ServerHello{ + Selected: ServerSelection{ + Message: MessageSelection{ + Codec: MessageCodecJSON, + }, + Crypto: CryptoSelection{ + Algorithm: algorithm, + ServerRandom: serverRandom, + }, + }, + }, nil +} + +func ValidateCryptoCapabilities(c CryptoCapabilities) error { + if len(c.ClientRandom) != CryptoRandomSize { + return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize) + } + if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok { + return fmt.Errorf("no supported crypto algorithm") + } + return nil +} + +func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error { + if serverHello.Selected.Message.Codec != MessageCodecJSON { + return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec) + } + cryptoSelection := serverHello.Selected.Crypto + if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) { + return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm) + } + if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) { + return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm) + } + if len(cryptoSelection.ServerRandom) != CryptoRandomSize { + return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize) + } + return nil +} + +func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext { + return &CryptoContext{ + Algorithm: algorithm, + TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload), + } +} + +func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) { + var clientHello ClientHello + if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil { + return nil, fmt.Errorf("decode ClientHello transcript: %w", err) + } + var serverHello ServerHello + if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil { + return nil, fmt.Errorf("decode ServerHello transcript: %w", err) + } + if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil { + return nil, err + } + + return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil +} + +func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte { + h := sha256.New() + _, _ = h.Write([]byte(cryptoTranscriptLabel)) + writeCryptoTranscriptPart(h, "client hello", clientHelloPayload) + writeCryptoTranscriptPart(h, "server hello", serverHelloPayload) + return h.Sum(nil) +} + +func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) { + var length [8]byte + binary.BigEndian.PutUint64(length[:], uint64(len(payload))) + _, _ = h.Write([]byte{0}) + _, _ = h.Write([]byte(label)) + _, _ = h.Write([]byte{0}) + _, _ = h.Write(length[:]) + _, _ = h.Write(payload) +} + +func PreferredAEADAlgorithms() []string { + if hasFastAESGCM() { + return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305} + } + return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM} +} + +func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) { + for _, algorithm := range clientAlgorithms { + if IsSupportedAEADAlgorithm(algorithm) { + return algorithm, true + } + } + return "", false +} + +func IsSupportedAEADAlgorithm(algorithm string) bool { + return Supports(supportedAEADAlgorithms, algorithm) +} + +func newCryptoRandom() ([]byte, error) { + b := make([]byte, CryptoRandomSize) + if _, err := rand.Read(b); err != nil { + return nil, fmt.Errorf("generate crypto random: %w", err) + } + return b, nil +} + +func hasFastAESGCM() bool { + switch runtime.GOARCH { + case "amd64": + return cpu.X86.HasAES && + cpu.X86.HasPCLMULQDQ && + cpu.X86.HasSSE41 && + cpu.X86.HasSSSE3 + case "arm64": + return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + case "s390x": + return cpu.S390X.HasAES && + cpu.S390X.HasAESCTR && + cpu.S390X.HasGHASH + case "ppc64", "ppc64le": + // Go's ppc64/ppc64le port targets POWER8+, which has AES instructions; + // x/sys/cpu does not expose a PPC64 AES feature flag. + return true + default: + return false + } +} diff --git a/pkg/proto/wire/wire.go b/pkg/proto/wire/wire.go index 5c38292d..47bf5984 100644 --- a/pkg/proto/wire/wire.go +++ b/pkg/proto/wire/wire.go @@ -120,15 +120,23 @@ func (c *Conn) UnmarshalFrame(f *Frame, out any) error { return json.Unmarshal(f.Payload, out) } -func (c *Conn) WriteJSONFrame(frameType uint16, in any) error { +func NewJSONFrame(frameType uint16, in any) (*Frame, error) { payload, err := json.Marshal(in) + if err != nil { + return nil, err + } + return &Frame{ + Type: frameType, + Payload: payload, + }, nil +} + +func (c *Conn) WriteJSONFrame(frameType uint16, in any) error { + f, err := NewJSONFrame(frameType, in) if err != nil { return err } - return c.WriteFrame(&Frame{ - Type: frameType, - Payload: payload, - }) + return c.WriteFrame(f) } func WriteMagic(w io.Writer) error { @@ -170,12 +178,18 @@ type ClientHello struct { type ClientCapabilities struct { Message MessageCapabilities `json:"message,omitempty"` + Crypto CryptoCapabilities `json:"crypto,omitempty"` } type MessageCapabilities struct { Codecs []string `json:"codecs,omitempty"` } +type CryptoCapabilities struct { + Algorithms []string `json:"algorithms,omitempty"` + ClientRandom []byte `json:"clientRandom,omitempty"` +} + type ServerHello struct { Selected ServerSelection `json:"selected,omitempty"` Error string `json:"error,omitempty"` @@ -183,19 +197,29 @@ type ServerHello struct { type ServerSelection struct { Message MessageSelection `json:"message,omitempty"` + Crypto CryptoSelection `json:"crypto,omitempty"` } type MessageSelection struct { Codec string `json:"codec,omitempty"` } -func DefaultClientHello(bootstrap BootstrapInfo) ClientHello { +type CryptoSelection struct { + Algorithm string `json:"algorithm,omitempty"` + ServerRandom []byte `json:"serverRandom,omitempty"` +} + +func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello { return ClientHello{ Bootstrap: bootstrap, Capabilities: ClientCapabilities{ Message: MessageCapabilities{ Codecs: []string{MessageCodecJSON}, }, + Crypto: CryptoCapabilities{ + Algorithms: PreferredAEADAlgorithms(), + ClientRandom: clientRandom, + }, }, } } @@ -218,5 +242,5 @@ func ValidateClientHello(h ClientHello) error { if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) { return fmt.Errorf("unsupported message codec") } - return nil + return ValidateCryptoCapabilities(h.Capabilities.Crypto) } diff --git a/pkg/proto/wire/wire_test.go b/pkg/proto/wire/wire_test.go index 19adb0be..b564f712 100644 --- a/pkg/proto/wire/wire_test.go +++ b/pkg/proto/wire/wire_test.go @@ -28,7 +28,7 @@ func TestFrameRoundTrip(t *testing.T) { var buf bytes.Buffer conn := NewConn(&buf) - in := DefaultClientHello(BootstrapInfo{ + in := mustClientHello(t, BootstrapInfo{ Transport: "tcp", TLS: true, TCPMux: true, @@ -112,9 +112,122 @@ func TestCheckMagicV1PreservesReadBytes(t *testing.T) { } func TestValidateClientHello(t *testing.T) { - require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{}))) + hello := mustClientHello(t, BootstrapInfo{}) + require.NoError(t, ValidateClientHello(hello)) + require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize) + require.ElementsMatch(t, []string{ + AEADAlgorithmAES256GCM, + AEADAlgorithmXChaCha20Poly1305, + }, hello.Capabilities.Crypto.Algorithms) - hello := DefaultClientHello(BootstrapInfo{}) hello.Capabilities.Message.Codecs = []string{"unknown"} require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec") } + +func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{}) + hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1] + require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length") + + hello = mustClientHello(t, BootstrapInfo{}) + hello.Capabilities.Crypto.Algorithms = []string{"unknown"} + require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm") +} + +func TestPreferredAEADAlgorithms(t *testing.T) { + require.ElementsMatch(t, []string{ + AEADAlgorithmAES256GCM, + AEADAlgorithmXChaCha20Poly1305, + }, PreferredAEADAlgorithms()) +} + +func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{}) + hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM} + + serverHello, err := NewServerHello(hello) + require.NoError(t, err) + require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec) + require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm) + require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize) +} + +func TestNewClientCryptoContextValidatesServerHello(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{}) + serverHello, err := NewServerHello(hello) + require.NoError(t, err) + clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello) + + ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload) + require.NoError(t, err) + require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm) + require.Len(t, ctx.TranscriptHash, 32) + + tampered := serverHello + tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...) + tampered.Selected.Crypto.ServerRandom[0] ^= 0xff + _, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered) + tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload) + require.NoError(t, err) + require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash) +} + +func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{ + Transport: "tcp", + TLS: true, + TCPMux: true, + }) + serverHello, err := NewServerHello(hello) + require.NoError(t, err) + clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello) + + ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload) + + tamperedHello := hello + tamperedHello.Bootstrap.TLS = false + tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello) + tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload) + require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash) +} + +func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{}) + serverHello, err := NewServerHello(hello) + require.NoError(t, err) + + serverHello.Selected.Crypto.Algorithm = "unknown" + clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello) + _, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload) + require.ErrorContains(t, err, "unknown selected crypto algorithm") +} + +func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) { + hello := mustClientHello(t, BootstrapInfo{}) + hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM} + serverHello, err := NewServerHello(hello) + require.NoError(t, err) + + serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305 + clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello) + _, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload) + require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client") +} + +func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello { + t.Helper() + + hello, err := NewClientHello(bootstrap) + require.NoError(t, err) + return hello +} + +func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) { + t.Helper() + + clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello) + require.NoError(t, err) + serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello) + require.NoError(t, err) + return clientHelloFrame.Payload, serverHelloFrame.Payload +} diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go index aa83b409..5dd605a5 100644 --- a/pkg/util/net/conn.go +++ b/pkg/util/net/conn.go @@ -16,14 +16,16 @@ package net import ( "context" + "crypto/sha256" "errors" "io" "net" "sync/atomic" "time" - "github.com/fatedier/golib/crypto" + libcrypto "github.com/fatedier/golib/crypto" quic "github.com/quic-go/quic-go" + "golang.org/x/crypto/hkdf" "github.com/fatedier/frp/pkg/util/xlog" ) @@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error { } func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) { - encReader := crypto.NewReader(rw, key) - encWriter, err := crypto.NewWriter(rw, key) + encReader := libcrypto.NewReader(rw, key) + encWriter, err := libcrypto.NewWriter(rw, key) if err != nil { return nil, err } @@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) { Writer: encWriter, }, nil } + +type AEADCryptoRole int + +const ( + AEADCryptoRoleClient AEADCryptoRole = iota + 1 + AEADCryptoRoleServer +) + +const ( + aeadControlHKDFInfoPrefix = "frp wire v2 control aead" + aeadDirectionClientToServer = "client-to-server" + aeadDirectionServerToClient = "server-to-client" +) + +// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2 +// control channel. Frames and their order are authenticated, but end-of-stream +// is not: a clean EOF at a frame boundary is returned as normal EOF by the +// underlying AEAD stream. Protocols that need truncation detection for finite +// objects must add their own authenticated final message. +func NewAEADCryptoReadWriter( + rw io.ReadWriter, + key []byte, + role AEADCryptoRole, + algorithm string, + transcriptHash []byte, +) (io.ReadWriter, error) { + clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash) + if err != nil { + return nil, err + } + + var readKey, writeKey []byte + switch role { + case AEADCryptoRoleClient: + readKey = serverToClientKey + writeKey = clientToServerKey + case AEADCryptoRoleServer: + readKey = clientToServerKey + writeKey = serverToClientKey + default: + return nil, errors.New("invalid aead crypto role") + } + + encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{ + Algorithm: libcrypto.AEADAlgorithm(algorithm), + Key: readKey, + }) + if err != nil { + return nil, err + } + encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{ + Algorithm: libcrypto.AEADAlgorithm(algorithm), + Key: writeKey, + }) + if err != nil { + return nil, err + } + return struct { + io.Reader + io.Writer + }{ + Reader: encReader, + Writer: encWriter, + }, nil +} + +func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) { + clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer) + if err != nil { + return nil, nil, err + } + serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient) + if err != nil { + return nil, nil, err + } + return clientToServerKey, serverToClientKey, nil +} + +func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) { + info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction) + reader := hkdf.New(sha256.New, key, transcriptHash, info) + out := make([]byte, libcrypto.AEADKeySize) + if _, err := io.ReadFull(reader, out); err != nil { + return nil, err + } + return out, nil +} diff --git a/pkg/util/net/conn_test.go b/pkg/util/net/conn_test.go new file mode 100644 index 00000000..42d3c06b --- /dev/null +++ b/pkg/util/net/conn_test.go @@ -0,0 +1,118 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "bytes" + "io" + stdnet "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) { + clientConn, serverConn := stdnet.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + key := []byte("token") + transcriptHash := bytes.Repeat([]byte{0x11}, 32) + clientRW, err := NewAEADCryptoReadWriter( + clientConn, + key, + AEADCryptoRoleClient, + wire.AEADAlgorithmXChaCha20Poly1305, + transcriptHash, + ) + require.NoError(t, err) + serverRW, err := NewAEADCryptoReadWriter( + serverConn, + key, + AEADCryptoRoleServer, + wire.AEADAlgorithmXChaCha20Poly1305, + transcriptHash, + ) + require.NoError(t, err) + + clientErrCh := make(chan error, 1) + go func() { + if _, err := clientRW.Write([]byte("ping")); err != nil { + clientErrCh <- err + return + } + buf := make([]byte, len("pong")) + _, err := io.ReadFull(clientRW, buf) + clientErrCh <- err + }() + + buf := make([]byte, len("ping")) + _, err = io.ReadFull(serverRW, buf) + require.NoError(t, err) + require.Equal(t, "ping", string(buf)) + _, err = serverRW.Write([]byte("pong")) + require.NoError(t, err) + require.NoError(t, <-clientErrCh) +} + +func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) { + clientConn, serverConn := stdnet.Pipe() + defer clientConn.Close() + defer serverConn.Close() + require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second))) + require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second))) + + key := []byte("token") + clientRW, err := NewAEADCryptoReadWriter( + clientConn, + key, + AEADCryptoRoleClient, + wire.AEADAlgorithmAES256GCM, + bytes.Repeat([]byte{0x22}, 32), + ) + require.NoError(t, err) + serverRW, err := NewAEADCryptoReadWriter( + serverConn, + key, + AEADCryptoRoleServer, + wire.AEADAlgorithmAES256GCM, + bytes.Repeat([]byte{0x33}, 32), + ) + require.NoError(t, err) + + writeErrCh := make(chan error, 1) + go func() { + _, err := clientRW.Write([]byte("ping")) + writeErrCh <- err + }() + + buf := make([]byte, len("ping")) + _, err = io.ReadFull(serverRW, buf) + require.Error(t, err) + require.NoError(t, <-writeErrCh) +} + +func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) { + clientToServerKey, serverToClientKey, err := deriveAEADControlKeys( + []byte("token"), + wire.AEADAlgorithmXChaCha20Poly1305, + bytes.Repeat([]byte{0x44}, 32), + ) + require.NoError(t, err) + require.NotEqual(t, clientToServerKey, serverToClientKey) +} diff --git a/server/service.go b/server/service.go index 5077c21f..7ec303e4 100644 --- a/server/service.go +++ b/server/service.go @@ -60,6 +60,7 @@ import ( const ( connReadTimeout time.Duration = 10 * time.Second + connWriteTimeout time.Duration = 5 * time.Second vhostReadWriteTimeout time.Duration = 30 * time.Second ) @@ -456,7 +457,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna controlConn := acceptedConn.conn if !internal { var controlRW io.ReadWriter - controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey()) + controlRW, err = acceptedConn.newControlReadWriter(conn, svr.auth.EncryptionKey()) if err == nil { controlConn = acceptedConn.messageConnFor(controlRW) } @@ -468,17 +469,23 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna if err != nil { xl.Warnf("register control error: %v", err) - _ = acceptedConn.conn.WriteMsg(&msg.LoginResp{ - Version: version.Full(), - Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), - }) + if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error { + return acceptedConn.conn.WriteMsg(&msg.LoginResp{ + Version: version.Full(), + Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)), + }) + }); writeErr != nil { + xl.Warnf("write login error response error: %v", writeErr) + } conn.Close() return } - if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{ - Version: version.Full(), - RunID: ctl.runID, - Error: "", + if err = writeWithDeadline(conn, connWriteTimeout, func() error { + return acceptedConn.conn.WriteMsg(&msg.LoginResp{ + Version: version.Full(), + RunID: ctl.runID, + Error: "", + }) }); err != nil { xl.Warnf("write login response error: %v", err) svr.ctlManager.Del(m.RunID, ctl) @@ -521,9 +528,10 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna } type acceptedConnection struct { - conn *msg.Conn - wireProtocol string - firstMsg msg.Message + conn *msg.Conn + wireProtocol string + cryptoContext *wire.CryptoContext + firstMsg msg.Message } func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) { @@ -544,7 +552,7 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep wireConn := wire.NewConn(conn) rw := msg.NewV2ReadWriterWithConn(wireConn) acceptedConn.conn = msg.NewConn(conn, rw) - acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn) + acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(conn, wireConn) } else { rw := msg.NewV1ReadWriter(conn) acceptedConn.conn = msg.NewConn(conn, rw) @@ -557,17 +565,41 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep return acceptedConn, nil } +func writeWithDeadline(conn net.Conn, timeout time.Duration, writeFn func() error) error { + _ = conn.SetWriteDeadline(time.Now().Add(timeout)) + defer func() { + _ = conn.SetWriteDeadline(time.Time{}) + }() + return writeFn() +} + func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn { return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol)) } -func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) { +func (ac *acceptedConnection) newControlReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) { + if ac.wireProtocol == wire.ProtocolV2 { + if ac.cryptoContext == nil { + return nil, fmt.Errorf("missing v2 crypto negotiation") + } + return netpkg.NewAEADCryptoReadWriter( + rw, + key, + netpkg.AEADCryptoRoleServer, + ac.cryptoContext.Algorithm, + ac.cryptoContext.TranscriptHash, + ) + } + return netpkg.NewCryptoReadWriter(rw, key) +} + +func (ac *acceptedConnection) readFirstV2Msg(conn net.Conn, wireConn *wire.Conn) (msg.Message, error) { frame, err := wireConn.ReadFrame() if err != nil { return nil, fmt.Errorf("read v2 frame: %w", err) } if frame.Type == wire.FrameTypeClientHello { - if err := ac.handleClientHello(wireConn, frame); err != nil { + if err := ac.handleClientHello(conn, wireConn, frame); err != nil { return nil, err } frame, err = wireConn.ReadFrame() @@ -583,21 +615,38 @@ func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, return m, nil } -func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error { +func (ac *acceptedConnection) handleClientHello(conn net.Conn, wireConn *wire.Conn, frame *wire.Frame) error { var hello wire.ClientHello if err := wireConn.UnmarshalFrame(frame, &hello); err != nil { return fmt.Errorf("decode ClientHello: %w", err) } - serverHello := wire.DefaultServerHello() - if err := wire.ValidateClientHello(hello); err != nil { + serverHello, err := wire.NewServerHello(hello) + if err != nil { + serverHello = wire.DefaultServerHello() serverHello.Error = err.Error() - _ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello) + if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error { + return wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello) + }); writeErr != nil { + return fmt.Errorf("%w; write ServerHello error: %v", err, writeErr) + } return err } - if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil { + serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello) + if err != nil { + return fmt.Errorf("encode ServerHello: %w", err) + } + cryptoContext := wire.NewCryptoContext( + serverHello.Selected.Crypto.Algorithm, + frame.Payload, + serverHelloFrame.Payload, + ) + if err := writeWithDeadline(conn, connWriteTimeout, func() error { + return wireConn.WriteFrame(serverHelloFrame) + }); err != nil { return fmt.Errorf("write ServerHello: %w", err) } + ac.cryptoContext = cryptoContext return nil } diff --git a/server/service_test.go b/server/service_test.go new file mode 100644 index 00000000..cf5eec5e --- /dev/null +++ b/server/service_test.go @@ -0,0 +1,63 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWriteWithDeadlineTimesOutAndClearsDeadline(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + err := writeWithDeadline(serverConn, 50*time.Millisecond, func() error { + _, writeErr := serverConn.Write([]byte("x")) + return writeErr + }) + require.Error(t, err) + + var netErr net.Error + require.True(t, errors.As(err, &netErr)) + require.True(t, netErr.Timeout()) + + readCh := make(chan byte, 1) + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 1) + if _, readErr := clientConn.Read(buf); readErr != nil { + errCh <- readErr + return + } + readCh <- buf[0] + }() + + _, err = serverConn.Write([]byte("y")) + require.NoError(t, err) + + select { + case b := <-readCh: + require.Equal(t, byte('y'), b) + case err := <-errCh: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timed out waiting for write after deadline reset") + } +}