Merge pull request #23 from mmatczuk/mmt/proto_no_regex

proto: header per control message property, no regex
This commit is contained in:
Michał Matczuk 2017-05-12 10:37:10 +02:00 committed by GitHub
commit d364f66c09
6 changed files with 130 additions and 83 deletions

View file

@ -216,7 +216,7 @@ func (c *Client) dial() (net.Conn, error) {
func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
if r.Header.Get(proto.ErrorHeader) != "" {
if r.Header.Get(proto.HeaderError) != "" {
c.handleHandshakeError(w, r)
} else {
c.handleHandshake(w, r)
@ -224,7 +224,7 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
return
}
msg, err := proto.ParseControlMessage(r.Header)
msg, err := proto.ReadControlMessage(r.Header)
if err != nil {
c.logger.Log(
"level", 1,
@ -240,7 +240,7 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
"ctrlMsg", msg,
)
switch msg.Action {
case proto.Proxy:
case proto.ActionProxy:
c.config.Proxy(w, r.Body, msg)
default:
c.logger.Log(
@ -258,7 +258,7 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) {
}
func (c *Client) handleHandshakeError(w http.ResponseWriter, r *http.Request) {
err := fmt.Errorf(r.Header.Get(proto.ErrorHeader))
err := fmt.Errorf(r.Header.Get(proto.HeaderError))
c.logger.Log(
"level", 1,

View file

@ -1,16 +0,0 @@
// Code generated by "stringer -type=Action"; DO NOT EDIT
package proto
import "fmt"
const _Action_name = "Proxy"
var _Action_index = [...]uint8{0, 5}
func (i Action) String() string {
if i < 0 || i >= Action(len(_Action_index)-1) {
return fmt.Sprintf("Action(%d)", i)
}
return _Action_name[_Action_index[i]:_Action_index[i+1]]
}

View file

@ -1,77 +1,83 @@
package proto
import (
"errors"
"fmt"
"net/http"
"regexp"
)
// Action represents type of ControlMessage.
type Action int
// ControlMessage actions.
// Protocol HTTP headers.
const (
Proxy Action = iota
HeaderAction = "T-Action"
HeaderError = "T-Error"
HeaderForwardedBy = "T-Forwarded-By"
HeaderForwardedFor = "T-Forwarded-For"
HeaderPath = "T-Path"
HeaderProtocol = "T-Proto"
)
// ControlMessage headers
// Known actions.
const (
ErrorHeader = "Error"
ForwardedHeader = "Forwarded"
ActionProxy string = "proxy"
)
// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "udp",
// "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" (IPv4-only),
// "ip6" (IPv6-only), "unix", "unixgram" and "unixpacket".
// Known protocol types.
const (
HTTP = "http"
TCP = "tcp"
TCP4 = "tcp4"
TCP6 = "tcp6"
UNIX = "unix"
WS = "ws"
)
// ControlMessage is sent from server to client to establish tunneled
// connection.
// ControlMessage is sent from server to client before streaming data. It's
// used to inform client about the data and action to take. Based on that client
// routes requests to backend services.
type ControlMessage struct {
Action Action
Action string
Protocol string
ForwardedFor string
ForwardedBy string
Path string
}
var xffRegexp = regexp.MustCompile("(for|proto|by)=([^;$]+)")
// ParseControlMessage creates new ControlMessage based on "Forwarded" http
// header.
func ParseControlMessage(h http.Header) (*ControlMessage, error) {
v := h.Get(ForwardedHeader)
if v == "" {
return nil, errors.New("missing Forwarded header")
// ReadControlMessage reads ControlMessage from HTTP headers.
func ReadControlMessage(h http.Header) (*ControlMessage, error) {
msg := ControlMessage{
Action: h.Get(HeaderAction),
Protocol: h.Get(HeaderProtocol),
ForwardedFor: h.Get(HeaderForwardedFor),
ForwardedBy: h.Get(HeaderForwardedBy),
Path: h.Get(HeaderPath),
}
var msg = &ControlMessage{}
var missing []string
for _, i := range xffRegexp.FindAllStringSubmatch(v, -1) {
switch i[1] {
case "for":
msg.ForwardedFor = i[2]
case "by":
msg.ForwardedBy = i[2]
case "proto":
msg.Protocol = i[2]
}
if msg.Action == "" {
missing = append(missing, HeaderAction)
}
if msg.Protocol == "" {
missing = append(missing, HeaderProtocol)
}
if msg.ForwardedFor == "" {
missing = append(missing, HeaderForwardedFor)
}
if msg.ForwardedBy == "" {
missing = append(missing, HeaderForwardedBy)
}
return msg, nil
if len(missing) != 0 {
return nil, fmt.Errorf("missing headers: %s", missing)
}
return &msg, nil
}
// Update writes ControlMessage to "Forwarded" http header, "by" and "for"
// parameters take form of full IP and port.
//
// See Forwarded header specification https://tools.ietf.org/html/rfc7239.
// Update writes ControlMessage to HTTP header.
func (c *ControlMessage) Update(h http.Header) {
v := fmt.Sprintf("for=%s; proto=%s; by=%s", c.ForwardedFor, c.Protocol, c.ForwardedBy)
h.Set(ForwardedHeader, v)
h.Set(HeaderAction, string(c.Action))
h.Set(HeaderProtocol, c.Protocol)
h.Set(HeaderForwardedFor, c.ForwardedFor)
h.Set(HeaderForwardedBy, c.ForwardedBy)
h.Set(HeaderPath, c.Path)
}

View file

@ -1,6 +1,7 @@
package proto
import (
"errors"
"net/http"
"reflect"
"testing"
@ -9,19 +10,67 @@ import (
func TestControlMessage_WriteParse(t *testing.T) {
t.Parallel()
msg := &ControlMessage{
Protocol: "tcp",
ForwardedFor: "127.0.0.1:58104",
ForwardedBy: "127.0.0.1:7777",
data := []struct {
msg *ControlMessage
err error
}{
{
&ControlMessage{
Action: "action",
Protocol: "protocol",
ForwardedFor: "host-for",
ForwardedBy: "host-by",
},
nil,
},
{
&ControlMessage{
Protocol: "protocol",
ForwardedFor: "host-for",
ForwardedBy: "host-by",
},
errors.New("missing headers: [T-Action]"),
},
{
&ControlMessage{
Action: "action",
ForwardedFor: "host-for",
ForwardedBy: "host-by",
},
errors.New("missing headers: [T-Proto]"),
},
{
&ControlMessage{
Action: "action",
Protocol: "protocol",
ForwardedBy: "host-by",
},
errors.New("missing headers: [T-Forwarded-For]"),
},
{
&ControlMessage{
Action: "action",
Protocol: "protocol",
ForwardedFor: "host-for",
},
errors.New("missing headers: [T-Forwarded-By]"),
},
}
var h = http.Header{}
msg.Update(h)
actual, err := ParseControlMessage(h)
if err != nil {
t.Errorf("Parse error %s", err)
}
if !reflect.DeepEqual(msg, actual) {
t.Errorf("Received %+v expected %+v", msg, actual)
for i, tt := range data {
h := http.Header{}
tt.msg.Update(h)
actual, err := ReadControlMessage(h)
if tt.err != nil {
if err == nil {
t.Error(i, "expected error")
} else if tt.err.Error() != err.Error() {
t.Error(i, tt.err, err)
}
} else {
if !reflect.DeepEqual(tt.msg, actual) {
t.Error(i, tt.msg, actual)
}
}
}
}

View file

@ -1,11 +1,19 @@
package proto
// Tunnel specifies tunnel entry point. Tunnel map is sent from client to server
// during handshake. Server tries to proxy connections to Host and Addr to
// client.
// Tunnel describes a single tunnel between client and server. When connecting
// client sends tunnels to server. If client gets connected server proxies
// connections to given Host and Addr to the client.
type Tunnel struct {
// Protocol specifies tunnel protocol, must be one of protocols known
// by the server.
Protocol string
Host string
Auth string
Addr string
// Host specified HTTP request host, it's required for HTTP and WS
// tunnels.
Host string
// Auth specifies HTTP basic auth credentials in form "user:password",
// if set server would protect HTTP and WS tunnels with basic auth.
Auth string
// Addr specifies TCP address server would listen on, it's required
// for TCP tunnels.
Addr string
}

View file

@ -325,7 +325,7 @@ func (s *Server) notifyError(serverError error, identifier id.ID) {
return
}
req.Header.Set(proto.ErrorHeader, serverError.Error())
req.Header.Set(proto.HeaderError, serverError.Error())
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
@ -421,7 +421,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
}
msg := &proto.ControlMessage{
Action: proto.Proxy,
Action: proto.ActionProxy,
Protocol: l.Addr().Network(),
ForwardedFor: conn.RemoteAddr().String(),
ForwardedBy: l.Addr().String(),
@ -522,7 +522,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// RoundTrip is http.RoundTriper implementation.
func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
msg := &proto.ControlMessage{
Action: proto.Proxy,
Action: proto.ActionProxy,
Protocol: proto.HTTP,
ForwardedFor: r.RemoteAddr,
ForwardedBy: r.Host,
@ -605,7 +605,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control
}
func (s *Server) proxyRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) {
if msg.Action != proto.Proxy {
if msg.Action != proto.ActionProxy {
panic("Invalid action")
}