forked from jshiffer/matterbridge
161 lines
3.7 KiB
Go
161 lines
3.7 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"math/rand"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/http/httputil"
|
||
|
"net/url"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/labstack/echo"
|
||
|
)
|
||
|
|
||
|
// TODO: Handle TLS proxy
|
||
|
|
||
|
type (
|
||
|
// ProxyConfig defines the config for Proxy middleware.
|
||
|
ProxyConfig struct {
|
||
|
// Skipper defines a function to skip middleware.
|
||
|
Skipper Skipper
|
||
|
|
||
|
// Balancer defines a load balancing technique.
|
||
|
// Required.
|
||
|
// Possible values:
|
||
|
// - RandomBalancer
|
||
|
// - RoundRobinBalancer
|
||
|
Balancer ProxyBalancer
|
||
|
}
|
||
|
|
||
|
// ProxyTarget defines the upstream target.
|
||
|
ProxyTarget struct {
|
||
|
URL *url.URL
|
||
|
}
|
||
|
|
||
|
// RandomBalancer implements a random load balancing technique.
|
||
|
RandomBalancer struct {
|
||
|
Targets []*ProxyTarget
|
||
|
random *rand.Rand
|
||
|
}
|
||
|
|
||
|
// RoundRobinBalancer implements a round-robin load balancing technique.
|
||
|
RoundRobinBalancer struct {
|
||
|
Targets []*ProxyTarget
|
||
|
i uint32
|
||
|
}
|
||
|
|
||
|
// ProxyBalancer defines an interface to implement a load balancing technique.
|
||
|
ProxyBalancer interface {
|
||
|
Next() *ProxyTarget
|
||
|
}
|
||
|
)
|
||
|
|
||
|
func proxyHTTP(t *ProxyTarget) http.Handler {
|
||
|
return httputil.NewSingleHostReverseProxy(t.URL)
|
||
|
}
|
||
|
|
||
|
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
h, ok := w.(http.Hijacker)
|
||
|
if !ok {
|
||
|
c.Error(errors.New("proxy raw, not a hijacker"))
|
||
|
return
|
||
|
}
|
||
|
in, _, err := h.Hijack()
|
||
|
if err != nil {
|
||
|
c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", r.URL, err))
|
||
|
return
|
||
|
}
|
||
|
defer in.Close()
|
||
|
|
||
|
out, err := net.Dial("tcp", t.URL.Host)
|
||
|
if err != nil {
|
||
|
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", r.URL, err))
|
||
|
c.Error(he)
|
||
|
return
|
||
|
}
|
||
|
defer out.Close()
|
||
|
|
||
|
err = r.Write(out)
|
||
|
if err != nil {
|
||
|
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request copy error=%v, url=%s", r.URL, err))
|
||
|
c.Error(he)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
errc := make(chan error, 2)
|
||
|
cp := func(dst io.Writer, src io.Reader) {
|
||
|
_, err := io.Copy(dst, src)
|
||
|
errc <- err
|
||
|
}
|
||
|
|
||
|
go cp(out, in)
|
||
|
go cp(in, out)
|
||
|
err = <-errc
|
||
|
if err != nil && err != io.EOF {
|
||
|
c.Logger().Errorf("proxy raw, error=%v, url=%s", r.URL, err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Next randomly returns an upstream target.
|
||
|
func (r *RandomBalancer) Next() *ProxyTarget {
|
||
|
if r.random == nil {
|
||
|
r.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||
|
}
|
||
|
return r.Targets[r.random.Intn(len(r.Targets))]
|
||
|
}
|
||
|
|
||
|
// Next returns an upstream target using round-robin technique.
|
||
|
func (r *RoundRobinBalancer) Next() *ProxyTarget {
|
||
|
r.i = r.i % uint32(len(r.Targets))
|
||
|
t := r.Targets[r.i]
|
||
|
atomic.AddUint32(&r.i, 1)
|
||
|
return t
|
||
|
}
|
||
|
|
||
|
// Proxy returns an HTTP/WebSocket reverse proxy middleware.
|
||
|
func Proxy(config ProxyConfig) echo.MiddlewareFunc {
|
||
|
// Defaults
|
||
|
if config.Skipper == nil {
|
||
|
config.Skipper = DefaultLoggerConfig.Skipper
|
||
|
}
|
||
|
if config.Balancer == nil {
|
||
|
panic("echo: proxy middleware requires balancer")
|
||
|
}
|
||
|
|
||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||
|
return func(c echo.Context) (err error) {
|
||
|
req := c.Request()
|
||
|
res := c.Response()
|
||
|
tgt := config.Balancer.Next()
|
||
|
|
||
|
// Fix header
|
||
|
if req.Header.Get(echo.HeaderXRealIP) == "" {
|
||
|
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
|
||
|
}
|
||
|
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
|
||
|
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
|
||
|
}
|
||
|
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
|
||
|
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
||
|
}
|
||
|
|
||
|
// Proxy
|
||
|
switch {
|
||
|
case c.IsWebSocket():
|
||
|
proxyRaw(tgt, c).ServeHTTP(res, req)
|
||
|
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||
|
default:
|
||
|
proxyHTTP(tgt).ServeHTTP(res, req)
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|