mirror of
https://github.com/mmatczuk/go-http-tunnel.git
synced 2026-05-15 14:16:17 -06:00
Add support for custom error pages
This commit is contained in:
parent
b901486b9f
commit
572eb25500
2 changed files with 29 additions and 13 deletions
22
errors.go
22
errors.go
|
|
@ -4,12 +4,24 @@
|
|||
|
||||
package tunnel
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errClientNotSubscribed = errors.New("client not subscribed")
|
||||
errClientNotConnected = errors.New("client not connected")
|
||||
errClientAlreadyConnected = errors.New("client already connected")
|
||||
errClientNotSubscribed = newError("clientNotSubscribed.html", "client not subscribed")
|
||||
errClientNotConnected = newError("clientNotConnected.html", "client not connected")
|
||||
errClientAlreadyConnected = newError("clientAlreadyConnected.html", "client already connected")
|
||||
|
||||
errUnauthorised = errors.New("unauthorised")
|
||||
errUnauthorised = newError("unauthorised.html", "unauthorised")
|
||||
)
|
||||
|
||||
func newError(fileName string, defaultMsg string) error {
|
||||
content, err := ioutil.ReadFile("html/errors/" + fileName)
|
||||
if err != nil {
|
||||
// handle the case where the file doesn't exist
|
||||
return errors.New(defaultMsg)
|
||||
}
|
||||
return errors.New(string(content))
|
||||
}
|
||||
|
|
|
|||
20
server.go
20
server.go
|
|
@ -564,22 +564,26 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
|
|||
// ServeHTTP proxies http connection to the client.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
resp, err := s.RoundTrip(r)
|
||||
if err == errUnauthorised {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
code := http.StatusBadGateway
|
||||
if err == errUnauthorised {
|
||||
w.Header().Set("WWW-Authenticate", "Basic realm=\"User Visible Realm\"")
|
||||
code = http.StatusUnauthorized
|
||||
} else if err == errClientNotSubscribed {
|
||||
code = http.StatusNotFound
|
||||
}
|
||||
s.logger.Log(
|
||||
"level", 0,
|
||||
"action", "round trip failed",
|
||||
"addr", r.RemoteAddr,
|
||||
"host", r.Host,
|
||||
"url", r.URL,
|
||||
"err", err,
|
||||
"code", code,
|
||||
)
|
||||
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(code)
|
||||
fmt.Fprintln(w, err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue