mirror of
https://github.com/fatedier/frp.git
synced 2026-05-15 08:05:49 -06:00
protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)
Some checks failed
golangci-lint / lint (push) Has been cancelled
Some checks failed
golangci-lint / lint (push) Has been cancelled
This commit is contained in:
parent
57bb9e80fe
commit
8666e3643f
15 changed files with 866 additions and 86 deletions
26
README.md
26
README.md
|
|
@ -13,6 +13,14 @@ frp is an open source project with its ongoing development made possible entirel
|
|||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/beclab/Olares" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_olares.jpeg">
|
||||
|
|
@ -32,24 +40,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai]
|
|||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
<br>
|
||||
<b>Requestly - Free & Open-Source alternative to Postman</b>
|
||||
<br>
|
||||
<sub>All-in-one platform to Test, Mock and Intercept APIs.</sub>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## What is frp?
|
||||
|
|
|
|||
26
README_zh.md
26
README_zh.md
|
|
@ -15,6 +15,14 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
|
|||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/beclab/Olares" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_olares.jpeg">
|
||||
|
|
@ -34,24 +42,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai]
|
|||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
<br>
|
||||
<b>Requestly - Free & Open-Source alternative to Postman</b>
|
||||
<br>
|
||||
<sub>All-in-one platform to Test, Mock and Intercept APIs.</sub>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## 为什么使用 frp ?
|
||||
|
|
|
|||
|
|
@ -10,9 +10,12 @@ This release introduces wire protocol v2 as a transition path for future frpc/fr
|
|||
|
||||
**The default value of `transport.wireProtocol` remains `v1` in this release.** Users can keep the default for now. To test v2 early, upgrade both frpc and frps to versions that support it, then set `transport.wireProtocol = "v2"` in frpc. A v2-enabled frpc cannot connect to an older frps.
|
||||
|
||||
When `transport.wireProtocol = "v2"` is enabled, the control channel uses negotiated AEAD encryption after the login handshake. Both frpc and frps must be upgraded to this release to use v2.
|
||||
|
||||
v1 will be deprecated when v2 becomes the default in a future release. It will continue to be supported until v0.78.0 is released, and may be removed in v0.78.0 or later.
|
||||
|
||||
## Features
|
||||
|
||||
* Added `transport.wireProtocol` for frpc to select the internal message protocol used between frpc and frps. Supported values are `v1` and `v2`.
|
||||
* Added client protocol visibility in the frps dashboard and `/api/clients` API. Online clients now report their negotiated protocol as `v1` or `v2`.
|
||||
* Wire protocol v2 now negotiates AEAD control-channel encryption. Supported algorithms are `xchacha20-poly1305` and `aes-256-gcm`; frpc advertises its preferred order based on local AES-GCM hardware support, and frps selects the first supported algorithm from that list.
|
||||
|
|
|
|||
|
|
@ -74,17 +74,18 @@ func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, erro
|
|||
return nil, err
|
||||
}
|
||||
|
||||
loginRespMsg, err := d.exchangeLogin(conn, loginMsg)
|
||||
loginResult, err := d.exchangeLogin(conn, loginMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loginRespMsg := loginResult.resp
|
||||
if loginRespMsg.Error != "" {
|
||||
return nil, errors.New(loginRespMsg.Error)
|
||||
}
|
||||
|
||||
var controlRW io.ReadWriter = conn
|
||||
if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
|
||||
controlRW, err = netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
|
||||
controlRW, err = d.newControlReadWriter(conn, loginResult.crypto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create control crypto read writer: %w", err)
|
||||
}
|
||||
|
|
@ -125,9 +126,16 @@ func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login,
|
|||
return loginMsg, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*msg.LoginResp, error) {
|
||||
type loginExchangeResult struct {
|
||||
resp *msg.LoginResp
|
||||
crypto *wire.CryptoContext
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*loginExchangeResult, error) {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
var wireConn *wire.Conn
|
||||
var clientHello wire.ClientHello
|
||||
var clientHelloPayload []byte
|
||||
|
||||
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||
if err := wire.WriteMagic(conn); err != nil {
|
||||
|
|
@ -136,14 +144,23 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login)
|
|||
|
||||
wireConn = wire.NewConn(conn)
|
||||
rw = msg.NewV2ReadWriterWithConn(wireConn)
|
||||
hello := wire.DefaultClientHello(wire.BootstrapInfo{
|
||||
var err error
|
||||
clientHello, err = wire.NewClientHello(wire.BootstrapInfo{
|
||||
Transport: d.common.Transport.Protocol,
|
||||
TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
|
||||
TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
|
||||
})
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeClientHello, hello); err != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeClientHello, clientHello)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := wireConn.WriteFrame(clientHelloFrame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHelloPayload = clientHelloFrame.Payload
|
||||
}
|
||||
if err := rw.WriteMsg(loginMsg); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -154,19 +171,50 @@ func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login)
|
|||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
var cryptoContext *wire.CryptoContext
|
||||
if wireConn != nil {
|
||||
serverHelloFrame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverHelloFrame.Type != wire.FrameTypeServerHello {
|
||||
return nil, fmt.Errorf("unexpected frame type %d, want %d", serverHelloFrame.Type, wire.FrameTypeServerHello)
|
||||
}
|
||||
var serverHello wire.ServerHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeServerHello, &serverHello); err != nil {
|
||||
if err := wireConn.UnmarshalFrame(serverHelloFrame, &serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverHello.Error != "" {
|
||||
return nil, errors.New(serverHello.Error)
|
||||
}
|
||||
cryptoContext, err = wire.NewClientCryptoContext(clientHelloPayload, serverHelloFrame.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var loginRespMsg msg.LoginResp
|
||||
if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &loginRespMsg, nil
|
||||
return &loginExchangeResult{
|
||||
resp: &loginRespMsg,
|
||||
crypto: cryptoContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) newControlReadWriter(conn net.Conn, cryptoContext *wire.CryptoContext) (io.ReadWriter, error) {
|
||||
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||
if cryptoContext == nil {
|
||||
return nil, errors.New("missing v2 crypto negotiation")
|
||||
}
|
||||
return netpkg.NewAEADCryptoReadWriter(
|
||||
conn,
|
||||
d.auth.EncryptionKey(),
|
||||
netpkg.AEADCryptoRoleClient,
|
||||
cryptoContext.Algorithm,
|
||||
cryptoContext.TranscriptHash,
|
||||
)
|
||||
}
|
||||
return netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type testConnector struct {
|
||||
|
|
@ -140,8 +141,17 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
|||
}
|
||||
|
||||
wireConn := wire.NewConn(serverRaw)
|
||||
clientHelloFrame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if clientHelloFrame.Type != wire.FrameTypeClientHello {
|
||||
serverErrCh <- fmt.Errorf("unexpected frame type %d, want %d", clientHelloFrame.Type, wire.FrameTypeClientHello)
|
||||
return
|
||||
}
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
||||
if err := wireConn.UnmarshalFrame(clientHelloFrame, &hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
|
@ -160,11 +170,52 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
|||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
||||
serverHello, err := wire.NewServerHello(hello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
||||
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
cryptoContext := wire.NewCryptoContext(
|
||||
serverHello.Selected.Crypto.Algorithm,
|
||||
clientHelloFrame.Payload,
|
||||
serverHelloFrame.Payload,
|
||||
)
|
||||
if err := wireConn.WriteFrame(serverHelloFrame); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
controlRW, err := netpkg.NewAEADCryptoReadWriter(
|
||||
serverRaw,
|
||||
[]byte("token"),
|
||||
netpkg.AEADCryptoRoleServer,
|
||||
cryptoContext.Algorithm,
|
||||
cryptoContext.TranscriptHash,
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
controlMsgRW := msg.NewReadWriter(controlRW, wire.ProtocolV2)
|
||||
var ping msg.Ping
|
||||
if err := controlMsgRW.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if ping.PrivilegeKey != "v2-ping" || ping.Timestamp != 12345 {
|
||||
serverErrCh <- fmt.Errorf("unexpected ping: %+v", ping)
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||
|
|
@ -177,6 +228,7 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
|||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{PrivilegeKey: "v2-ping", Timestamp: 12345}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,51 @@ git commit -m "bump version to vX.Y.Z"
|
|||
git push origin dev
|
||||
```
|
||||
|
||||
## 3. Merge dev → master
|
||||
## 3. Pre-release Validation
|
||||
|
||||
Run the standard e2e suite locally:
|
||||
|
||||
```bash
|
||||
make e2e
|
||||
```
|
||||
|
||||
For releases that touch compatibility-sensitive areas such as login, control
|
||||
connections, work connections, visitors, transport, or wire protocol handling,
|
||||
also run the manual compatibility e2e suite:
|
||||
|
||||
```bash
|
||||
make e2e-compatibility
|
||||
make e2e-compatibility-floor
|
||||
```
|
||||
|
||||
`make e2e-compatibility` builds the current `frps` and `frpc`, resolves the
|
||||
recent stable release baselines from GitHub, downloads or reuses their binaries,
|
||||
and tests current binaries against those historical releases. The default number
|
||||
of recent baselines is controlled by `FRP_COMPAT_BASELINE_COUNT` in the
|
||||
`Makefile`.
|
||||
|
||||
Downloaded release binaries are cached under:
|
||||
|
||||
```text
|
||||
.cache/e2e-compat/<version>/<os>_<arch>/
|
||||
```
|
||||
|
||||
For a release validation run that must be exactly reproducible, pass an explicit
|
||||
baseline matrix instead of using the floating recent-release list:
|
||||
|
||||
```bash
|
||||
FRP_COMPAT_BASELINE_VERSIONS="0.X.0 0.Y.0" make e2e-compatibility
|
||||
```
|
||||
|
||||
Use `make e2e-compatibility-smoke` for a quick single-baseline check while
|
||||
iterating locally. If GitHub release metadata requests are rate-limited, set
|
||||
`GITHUB_TOKEN` or use `FRP_COMPAT_BASELINE_VERSIONS`.
|
||||
|
||||
The compatibility floor is a support-policy decision, not a value that should
|
||||
change every release. Update `FRP_COMPAT_FLOOR_VERSION` only when the declared
|
||||
compatibility window changes.
|
||||
|
||||
## 4. Merge dev → master
|
||||
|
||||
Create a PR from `dev` to `master`:
|
||||
|
||||
|
|
@ -43,7 +87,7 @@ gh pr create --base master --head dev --title "bump version"
|
|||
|
||||
Wait for CI to pass, then merge using **merge commit** (not squash).
|
||||
|
||||
## 4. Tag the Release
|
||||
## 5. Tag the Release
|
||||
|
||||
```bash
|
||||
git checkout master
|
||||
|
|
@ -52,7 +96,7 @@ git tag -a vX.Y.Z -m "bump version"
|
|||
git push origin vX.Y.Z
|
||||
```
|
||||
|
||||
## 5. Trigger GoReleaser
|
||||
## 6. Trigger GoReleaser
|
||||
|
||||
Manually trigger the `goreleaser` workflow in GitHub Actions:
|
||||
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -5,7 +5,7 @@ go 1.25.0
|
|||
require (
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/fatedier/golib v0.6.0
|
||||
github.com/fatedier/golib v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
|
|
@ -30,6 +30,7 @@ require (
|
|||
golang.org/x/net v0.52.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/time v0.10.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
gopkg.in/ini.v1 v1.67.0
|
||||
|
|
@ -68,7 +69,6 @@ require (
|
|||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fatedier/golib v0.6.0 h1:/mgBZZbkbMhIEZoXf7nV8knpUDzas/b+2ruYKxx1lww=
|
||||
github.com/fatedier/golib v0.6.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
|
||||
github.com/fatedier/golib v0.7.0 h1:tMDF9ObcwVt59VUHroJOzHQjVFPLymZVMpGm9WAVwhY=
|
||||
github.com/fatedier/golib v0.7.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
|
||||
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6 h1:u92UUy6FURPmNsMBUuongRWC0rBqN6gd01Dzu+D21NE=
|
||||
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6/go.mod h1:c5/tk6G0dSpXGzJN7Wk1OEie8grdSJAmeawId9Zvd34=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
|
|
|
|||
197
pkg/proto/wire/crypto.go
Normal file
197
pkg/proto/wire/crypto.go
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
const (
|
||||
AEADAlgorithmAES256GCM = "aes-256-gcm"
|
||||
AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305"
|
||||
|
||||
CryptoRandomSize = 32
|
||||
|
||||
cryptoTranscriptLabel = "frp wire v2 crypto transcript"
|
||||
)
|
||||
|
||||
var supportedAEADAlgorithms = []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}
|
||||
|
||||
type CryptoContext struct {
|
||||
Algorithm string
|
||||
TranscriptHash []byte
|
||||
}
|
||||
|
||||
func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) {
|
||||
clientRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ClientHello{}, err
|
||||
}
|
||||
return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil
|
||||
}
|
||||
|
||||
func NewServerHello(clientHello ClientHello) (ServerHello, error) {
|
||||
if err := ValidateClientHello(clientHello); err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms)
|
||||
if !ok {
|
||||
return ServerHello{}, fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
serverRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
return ServerHello{
|
||||
Selected: ServerSelection{
|
||||
Message: MessageSelection{
|
||||
Codec: MessageCodecJSON,
|
||||
},
|
||||
Crypto: CryptoSelection{
|
||||
Algorithm: algorithm,
|
||||
ServerRandom: serverRandom,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ValidateCryptoCapabilities(c CryptoCapabilities) error {
|
||||
if len(c.ClientRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize)
|
||||
}
|
||||
if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok {
|
||||
return fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error {
|
||||
if serverHello.Selected.Message.Codec != MessageCodecJSON {
|
||||
return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec)
|
||||
}
|
||||
cryptoSelection := serverHello.Selected.Crypto
|
||||
if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if len(cryptoSelection.ServerRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext {
|
||||
return &CryptoContext{
|
||||
Algorithm: algorithm,
|
||||
TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) {
|
||||
var clientHello ClientHello
|
||||
if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ClientHello transcript: %w", err)
|
||||
}
|
||||
var serverHello ServerHello
|
||||
if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ServerHello transcript: %w", err)
|
||||
}
|
||||
if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil
|
||||
}
|
||||
|
||||
func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte {
|
||||
h := sha256.New()
|
||||
_, _ = h.Write([]byte(cryptoTranscriptLabel))
|
||||
writeCryptoTranscriptPart(h, "client hello", clientHelloPayload)
|
||||
writeCryptoTranscriptPart(h, "server hello", serverHelloPayload)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) {
|
||||
var length [8]byte
|
||||
binary.BigEndian.PutUint64(length[:], uint64(len(payload)))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write([]byte(label))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write(length[:])
|
||||
_, _ = h.Write(payload)
|
||||
}
|
||||
|
||||
func PreferredAEADAlgorithms() []string {
|
||||
if hasFastAESGCM() {
|
||||
return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305}
|
||||
}
|
||||
return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
}
|
||||
|
||||
func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) {
|
||||
for _, algorithm := range clientAlgorithms {
|
||||
if IsSupportedAEADAlgorithm(algorithm) {
|
||||
return algorithm, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func IsSupportedAEADAlgorithm(algorithm string) bool {
|
||||
return Supports(supportedAEADAlgorithms, algorithm)
|
||||
}
|
||||
|
||||
func newCryptoRandom() ([]byte, error) {
|
||||
b := make([]byte, CryptoRandomSize)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, fmt.Errorf("generate crypto random: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func hasFastAESGCM() bool {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return cpu.X86.HasAES &&
|
||||
cpu.X86.HasPCLMULQDQ &&
|
||||
cpu.X86.HasSSE41 &&
|
||||
cpu.X86.HasSSSE3
|
||||
case "arm64":
|
||||
return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
case "s390x":
|
||||
return cpu.S390X.HasAES &&
|
||||
cpu.S390X.HasAESCTR &&
|
||||
cpu.S390X.HasGHASH
|
||||
case "ppc64", "ppc64le":
|
||||
// Go's ppc64/ppc64le port targets POWER8+, which has AES instructions;
|
||||
// x/sys/cpu does not expose a PPC64 AES feature flag.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -120,15 +120,23 @@ func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
|
|||
return json.Unmarshal(f.Payload, out)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
func NewJSONFrame(frameType uint16, in any) (*Frame, error) {
|
||||
payload, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
f, err := NewJSONFrame(frameType, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WriteFrame(&Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
})
|
||||
return c.WriteFrame(f)
|
||||
}
|
||||
|
||||
func WriteMagic(w io.Writer) error {
|
||||
|
|
@ -170,12 +178,18 @@ type ClientHello struct {
|
|||
|
||||
type ClientCapabilities struct {
|
||||
Message MessageCapabilities `json:"message,omitempty"`
|
||||
Crypto CryptoCapabilities `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageCapabilities struct {
|
||||
Codecs []string `json:"codecs,omitempty"`
|
||||
}
|
||||
|
||||
type CryptoCapabilities struct {
|
||||
Algorithms []string `json:"algorithms,omitempty"`
|
||||
ClientRandom []byte `json:"clientRandom,omitempty"`
|
||||
}
|
||||
|
||||
type ServerHello struct {
|
||||
Selected ServerSelection `json:"selected,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
|
|
@ -183,19 +197,29 @@ type ServerHello struct {
|
|||
|
||||
type ServerSelection struct {
|
||||
Message MessageSelection `json:"message,omitempty"`
|
||||
Crypto CryptoSelection `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageSelection struct {
|
||||
Codec string `json:"codec,omitempty"`
|
||||
}
|
||||
|
||||
func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
|
||||
type CryptoSelection struct {
|
||||
Algorithm string `json:"algorithm,omitempty"`
|
||||
ServerRandom []byte `json:"serverRandom,omitempty"`
|
||||
}
|
||||
|
||||
func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello {
|
||||
return ClientHello{
|
||||
Bootstrap: bootstrap,
|
||||
Capabilities: ClientCapabilities{
|
||||
Message: MessageCapabilities{
|
||||
Codecs: []string{MessageCodecJSON},
|
||||
},
|
||||
Crypto: CryptoCapabilities{
|
||||
Algorithms: PreferredAEADAlgorithms(),
|
||||
ClientRandom: clientRandom,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -218,5 +242,5 @@ func ValidateClientHello(h ClientHello) error {
|
|||
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
|
||||
return fmt.Errorf("unsupported message codec")
|
||||
}
|
||||
return nil
|
||||
return ValidateCryptoCapabilities(h.Capabilities.Crypto)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ func TestFrameRoundTrip(t *testing.T) {
|
|||
var buf bytes.Buffer
|
||||
conn := NewConn(&buf)
|
||||
|
||||
in := DefaultClientHello(BootstrapInfo{
|
||||
in := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
|
|
@ -112,9 +112,122 @@ func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestValidateClientHello(t *testing.T) {
|
||||
require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
require.NoError(t, ValidateClientHello(hello))
|
||||
require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize)
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, hello.Capabilities.Crypto.Algorithms)
|
||||
|
||||
hello := DefaultClientHello(BootstrapInfo{})
|
||||
hello.Capabilities.Message.Codecs = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
|
||||
}
|
||||
|
||||
func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1]
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length")
|
||||
|
||||
hello = mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm")
|
||||
}
|
||||
|
||||
func TestPreferredAEADAlgorithms(t *testing.T) {
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, PreferredAEADAlgorithms())
|
||||
}
|
||||
|
||||
func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec)
|
||||
require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm)
|
||||
require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextValidatesServerHello(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm)
|
||||
require.Len(t, ctx.TranscriptHash, 32)
|
||||
|
||||
tampered := serverHello
|
||||
tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...)
|
||||
tampered.Selected.Crypto.ServerRandom[0] ^= 0xff
|
||||
_, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered)
|
||||
tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload)
|
||||
|
||||
tamperedHello := hello
|
||||
tamperedHello.Bootstrap.TLS = false
|
||||
tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello)
|
||||
tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = "unknown"
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "unknown selected crypto algorithm")
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM}
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client")
|
||||
}
|
||||
|
||||
func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello {
|
||||
t.Helper()
|
||||
|
||||
hello, err := NewClientHello(bootstrap)
|
||||
require.NoError(t, err)
|
||||
return hello
|
||||
}
|
||||
|
||||
func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello)
|
||||
require.NoError(t, err)
|
||||
serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello)
|
||||
require.NoError(t, err)
|
||||
return clientHelloFrame.Payload, serverHelloFrame.Payload
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,14 +16,16 @@ package net
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/crypto"
|
||||
libcrypto "github.com/fatedier/golib/crypto"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
|
|
@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
|
|||
}
|
||||
|
||||
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
encReader := crypto.NewReader(rw, key)
|
||||
encWriter, err := crypto.NewWriter(rw, key)
|
||||
encReader := libcrypto.NewReader(rw, key)
|
||||
encWriter, err := libcrypto.NewWriter(rw, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
|||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type AEADCryptoRole int
|
||||
|
||||
const (
|
||||
AEADCryptoRoleClient AEADCryptoRole = iota + 1
|
||||
AEADCryptoRoleServer
|
||||
)
|
||||
|
||||
const (
|
||||
aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
|
||||
aeadDirectionClientToServer = "client-to-server"
|
||||
aeadDirectionServerToClient = "server-to-client"
|
||||
)
|
||||
|
||||
// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
|
||||
// control channel. Frames and their order are authenticated, but end-of-stream
|
||||
// is not: a clean EOF at a frame boundary is returned as normal EOF by the
|
||||
// underlying AEAD stream. Protocols that need truncation detection for finite
|
||||
// objects must add their own authenticated final message.
|
||||
func NewAEADCryptoReadWriter(
|
||||
rw io.ReadWriter,
|
||||
key []byte,
|
||||
role AEADCryptoRole,
|
||||
algorithm string,
|
||||
transcriptHash []byte,
|
||||
) (io.ReadWriter, error) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var readKey, writeKey []byte
|
||||
switch role {
|
||||
case AEADCryptoRoleClient:
|
||||
readKey = serverToClientKey
|
||||
writeKey = clientToServerKey
|
||||
case AEADCryptoRoleServer:
|
||||
readKey = clientToServerKey
|
||||
writeKey = serverToClientKey
|
||||
default:
|
||||
return nil, errors.New("invalid aead crypto role")
|
||||
}
|
||||
|
||||
encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: readKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: writeKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{
|
||||
Reader: encReader,
|
||||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
|
||||
clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return clientToServerKey, serverToClientKey, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
|
||||
info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
|
||||
reader := hkdf.New(sha256.New, key, transcriptHash, info)
|
||||
out := make([]byte, libcrypto.AEADKeySize)
|
||||
if _, err := io.ReadFull(reader, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
|
|
|||
118
pkg/util/net/conn_test.go
Normal file
118
pkg/util/net/conn_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
stdnet "net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
key := []byte("token")
|
||||
transcriptHash := bytes.Repeat([]byte{0x11}, 32)
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
if _, err := clientRW.Write([]byte("ping")); err != nil {
|
||||
clientErrCh <- err
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len("pong"))
|
||||
_, err := io.ReadFull(clientRW, buf)
|
||||
clientErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ping", string(buf))
|
||||
_, err = serverRW.Write([]byte("pong"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, <-clientErrCh)
|
||||
}
|
||||
|
||||
func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
|
||||
key := []byte("token")
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x22}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x33}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := clientRW.Write([]byte("ping"))
|
||||
writeErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, <-writeErrCh)
|
||||
}
|
||||
|
||||
func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
|
||||
[]byte("token"),
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
bytes.Repeat([]byte{0x44}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, clientToServerKey, serverToClientKey)
|
||||
}
|
||||
|
|
@ -60,6 +60,7 @@ import (
|
|||
|
||||
const (
|
||||
connReadTimeout time.Duration = 10 * time.Second
|
||||
connWriteTimeout time.Duration = 5 * time.Second
|
||||
vhostReadWriteTimeout time.Duration = 30 * time.Second
|
||||
)
|
||||
|
||||
|
|
@ -456,7 +457,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
|||
controlConn := acceptedConn.conn
|
||||
if !internal {
|
||||
var controlRW io.ReadWriter
|
||||
controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
|
||||
controlRW, err = acceptedConn.newControlReadWriter(conn, svr.auth.EncryptionKey())
|
||||
if err == nil {
|
||||
controlConn = acceptedConn.messageConnFor(controlRW)
|
||||
}
|
||||
|
|
@ -468,17 +469,23 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
|||
|
||||
if err != nil {
|
||||
xl.Warnf("register control error: %v", err)
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
}); writeErr != nil {
|
||||
xl.Warnf("write login error response error: %v", writeErr)
|
||||
}
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
if err = writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
})
|
||||
}); err != nil {
|
||||
xl.Warnf("write login response error: %v", err)
|
||||
svr.ctlManager.Del(m.RunID, ctl)
|
||||
|
|
@ -523,6 +530,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
|||
type acceptedConnection struct {
|
||||
conn *msg.Conn
|
||||
wireProtocol string
|
||||
cryptoContext *wire.CryptoContext
|
||||
firstMsg msg.Message
|
||||
}
|
||||
|
||||
|
|
@ -544,7 +552,7 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
|
|||
wireConn := wire.NewConn(conn)
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
|
||||
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(conn, wireConn)
|
||||
} else {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
|
|
@ -557,17 +565,41 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
|
|||
return acceptedConn, nil
|
||||
}
|
||||
|
||||
func writeWithDeadline(conn net.Conn, timeout time.Duration, writeFn func() error) error {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
defer func() {
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
}()
|
||||
return writeFn()
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
|
||||
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
|
||||
func (ac *acceptedConnection) newControlReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
if ac.wireProtocol == wire.ProtocolV2 {
|
||||
if ac.cryptoContext == nil {
|
||||
return nil, fmt.Errorf("missing v2 crypto negotiation")
|
||||
}
|
||||
return netpkg.NewAEADCryptoReadWriter(
|
||||
rw,
|
||||
key,
|
||||
netpkg.AEADCryptoRoleServer,
|
||||
ac.cryptoContext.Algorithm,
|
||||
ac.cryptoContext.TranscriptHash,
|
||||
)
|
||||
}
|
||||
return netpkg.NewCryptoReadWriter(rw, key)
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) readFirstV2Msg(conn net.Conn, wireConn *wire.Conn) (msg.Message, error) {
|
||||
frame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read v2 frame: %w", err)
|
||||
}
|
||||
if frame.Type == wire.FrameTypeClientHello {
|
||||
if err := ac.handleClientHello(wireConn, frame); err != nil {
|
||||
if err := ac.handleClientHello(conn, wireConn, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame, err = wireConn.ReadFrame()
|
||||
|
|
@ -583,21 +615,38 @@ func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message,
|
|||
return m, nil
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
|
||||
func (ac *acceptedConnection) handleClientHello(conn net.Conn, wireConn *wire.Conn, frame *wire.Frame) error {
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
|
||||
return fmt.Errorf("decode ClientHello: %w", err)
|
||||
}
|
||||
|
||||
serverHello := wire.DefaultServerHello()
|
||||
if err := wire.ValidateClientHello(hello); err != nil {
|
||||
serverHello, err := wire.NewServerHello(hello)
|
||||
if err != nil {
|
||||
serverHello = wire.DefaultServerHello()
|
||||
serverHello.Error = err.Error()
|
||||
_ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
}); writeErr != nil {
|
||||
return fmt.Errorf("%w; write ServerHello error: %v", err, writeErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
|
||||
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode ServerHello: %w", err)
|
||||
}
|
||||
cryptoContext := wire.NewCryptoContext(
|
||||
serverHello.Selected.Crypto.Algorithm,
|
||||
frame.Payload,
|
||||
serverHelloFrame.Payload,
|
||||
)
|
||||
if err := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return wireConn.WriteFrame(serverHelloFrame)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("write ServerHello: %w", err)
|
||||
}
|
||||
ac.cryptoContext = cryptoContext
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
63
server/service_test.go
Normal file
63
server/service_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteWithDeadlineTimesOutAndClearsDeadline(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
err := writeWithDeadline(serverConn, 50*time.Millisecond, func() error {
|
||||
_, writeErr := serverConn.Write([]byte("x"))
|
||||
return writeErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var netErr net.Error
|
||||
require.True(t, errors.As(err, &netErr))
|
||||
require.True(t, netErr.Timeout())
|
||||
|
||||
readCh := make(chan byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
if _, readErr := clientConn.Read(buf); readErr != nil {
|
||||
errCh <- readErr
|
||||
return
|
||||
}
|
||||
readCh <- buf[0]
|
||||
}()
|
||||
|
||||
_, err = serverConn.Write([]byte("y"))
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case b := <-readCh:
|
||||
require.Equal(t, byte('y'), b)
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for write after deadline reset")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue