forked from lug/matterbridge
		
	
		
			
				
	
	
		
			222 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			222 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package middleware
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/subtle"
 | 
						|
	"errors"
 | 
						|
	"net/http"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/labstack/echo/v4"
 | 
						|
	"github.com/labstack/gommon/random"
 | 
						|
)
 | 
						|
 | 
						|
type (
 | 
						|
	// CSRFConfig defines the config for CSRF middleware.
 | 
						|
	CSRFConfig struct {
 | 
						|
		// Skipper defines a function to skip middleware.
 | 
						|
		Skipper Skipper
 | 
						|
 | 
						|
		// TokenLength is the length of the generated token.
 | 
						|
		TokenLength uint8 `yaml:"token_length"`
 | 
						|
		// Optional. Default value 32.
 | 
						|
 | 
						|
		// TokenLookup is a string in the form of "<source>:<key>" that is used
 | 
						|
		// to extract token from the request.
 | 
						|
		// Optional. Default value "header:X-CSRF-Token".
 | 
						|
		// Possible values:
 | 
						|
		// - "header:<name>"
 | 
						|
		// - "form:<name>"
 | 
						|
		// - "query:<name>"
 | 
						|
		TokenLookup string `yaml:"token_lookup"`
 | 
						|
 | 
						|
		// Context key to store generated CSRF token into context.
 | 
						|
		// Optional. Default value "csrf".
 | 
						|
		ContextKey string `yaml:"context_key"`
 | 
						|
 | 
						|
		// Name of the CSRF cookie. This cookie will store CSRF token.
 | 
						|
		// Optional. Default value "csrf".
 | 
						|
		CookieName string `yaml:"cookie_name"`
 | 
						|
 | 
						|
		// Domain of the CSRF cookie.
 | 
						|
		// Optional. Default value none.
 | 
						|
		CookieDomain string `yaml:"cookie_domain"`
 | 
						|
 | 
						|
		// Path of the CSRF cookie.
 | 
						|
		// Optional. Default value none.
 | 
						|
		CookiePath string `yaml:"cookie_path"`
 | 
						|
 | 
						|
		// Max age (in seconds) of the CSRF cookie.
 | 
						|
		// Optional. Default value 86400 (24hr).
 | 
						|
		CookieMaxAge int `yaml:"cookie_max_age"`
 | 
						|
 | 
						|
		// Indicates if CSRF cookie is secure.
 | 
						|
		// Optional. Default value false.
 | 
						|
		CookieSecure bool `yaml:"cookie_secure"`
 | 
						|
 | 
						|
		// Indicates if CSRF cookie is HTTP only.
 | 
						|
		// Optional. Default value false.
 | 
						|
		CookieHTTPOnly bool `yaml:"cookie_http_only"`
 | 
						|
 | 
						|
		// Indicates SameSite mode of the CSRF cookie.
 | 
						|
		// Optional. Default value SameSiteDefaultMode.
 | 
						|
		CookieSameSite http.SameSite `yaml:"cookie_same_site"`
 | 
						|
	}
 | 
						|
 | 
						|
	// csrfTokenExtractor defines a function that takes `echo.Context` and returns
 | 
						|
	// either a token or an error.
 | 
						|
	csrfTokenExtractor func(echo.Context) (string, error)
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	// DefaultCSRFConfig is the default CSRF middleware config.
 | 
						|
	DefaultCSRFConfig = CSRFConfig{
 | 
						|
		Skipper:        DefaultSkipper,
 | 
						|
		TokenLength:    32,
 | 
						|
		TokenLookup:    "header:" + echo.HeaderXCSRFToken,
 | 
						|
		ContextKey:     "csrf",
 | 
						|
		CookieName:     "_csrf",
 | 
						|
		CookieMaxAge:   86400,
 | 
						|
		CookieSameSite: http.SameSiteDefaultMode,
 | 
						|
	}
 | 
						|
)
 | 
						|
 | 
						|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
 | 
						|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
 | 
						|
func CSRF() echo.MiddlewareFunc {
 | 
						|
	c := DefaultCSRFConfig
 | 
						|
	return CSRFWithConfig(c)
 | 
						|
}
 | 
						|
 | 
						|
// CSRFWithConfig returns a CSRF middleware with config.
 | 
						|
// See `CSRF()`.
 | 
						|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
 | 
						|
	// Defaults
 | 
						|
	if config.Skipper == nil {
 | 
						|
		config.Skipper = DefaultCSRFConfig.Skipper
 | 
						|
	}
 | 
						|
	if config.TokenLength == 0 {
 | 
						|
		config.TokenLength = DefaultCSRFConfig.TokenLength
 | 
						|
	}
 | 
						|
	if config.TokenLookup == "" {
 | 
						|
		config.TokenLookup = DefaultCSRFConfig.TokenLookup
 | 
						|
	}
 | 
						|
	if config.ContextKey == "" {
 | 
						|
		config.ContextKey = DefaultCSRFConfig.ContextKey
 | 
						|
	}
 | 
						|
	if config.CookieName == "" {
 | 
						|
		config.CookieName = DefaultCSRFConfig.CookieName
 | 
						|
	}
 | 
						|
	if config.CookieMaxAge == 0 {
 | 
						|
		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
 | 
						|
	}
 | 
						|
	if config.CookieSameSite == http.SameSiteNoneMode {
 | 
						|
		config.CookieSecure = true
 | 
						|
	}
 | 
						|
 | 
						|
	// Initialize
 | 
						|
	parts := strings.Split(config.TokenLookup, ":")
 | 
						|
	extractor := csrfTokenFromHeader(parts[1])
 | 
						|
	switch parts[0] {
 | 
						|
	case "form":
 | 
						|
		extractor = csrfTokenFromForm(parts[1])
 | 
						|
	case "query":
 | 
						|
		extractor = csrfTokenFromQuery(parts[1])
 | 
						|
	}
 | 
						|
 | 
						|
	return func(next echo.HandlerFunc) echo.HandlerFunc {
 | 
						|
		return func(c echo.Context) error {
 | 
						|
			if config.Skipper(c) {
 | 
						|
				return next(c)
 | 
						|
			}
 | 
						|
 | 
						|
			req := c.Request()
 | 
						|
			k, err := c.Cookie(config.CookieName)
 | 
						|
			token := ""
 | 
						|
 | 
						|
			// Generate token
 | 
						|
			if err != nil {
 | 
						|
				token = random.String(config.TokenLength)
 | 
						|
			} else {
 | 
						|
				// Reuse token
 | 
						|
				token = k.Value
 | 
						|
			}
 | 
						|
 | 
						|
			switch req.Method {
 | 
						|
			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
 | 
						|
			default:
 | 
						|
				// Validate token only for requests which are not defined as 'safe' by RFC7231
 | 
						|
				clientToken, err := extractor(c)
 | 
						|
				if err != nil {
 | 
						|
					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
 | 
						|
				}
 | 
						|
				if !validateCSRFToken(token, clientToken) {
 | 
						|
					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			// Set CSRF cookie
 | 
						|
			cookie := new(http.Cookie)
 | 
						|
			cookie.Name = config.CookieName
 | 
						|
			cookie.Value = token
 | 
						|
			if config.CookiePath != "" {
 | 
						|
				cookie.Path = config.CookiePath
 | 
						|
			}
 | 
						|
			if config.CookieDomain != "" {
 | 
						|
				cookie.Domain = config.CookieDomain
 | 
						|
			}
 | 
						|
			if config.CookieSameSite != http.SameSiteDefaultMode {
 | 
						|
				cookie.SameSite = config.CookieSameSite
 | 
						|
			}
 | 
						|
			cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
 | 
						|
			cookie.Secure = config.CookieSecure
 | 
						|
			cookie.HttpOnly = config.CookieHTTPOnly
 | 
						|
			c.SetCookie(cookie)
 | 
						|
 | 
						|
			// Store token in the context
 | 
						|
			c.Set(config.ContextKey, token)
 | 
						|
 | 
						|
			// Protect clients from caching the response
 | 
						|
			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
 | 
						|
 | 
						|
			return next(c)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
 | 
						|
// provided request header.
 | 
						|
func csrfTokenFromHeader(header string) csrfTokenExtractor {
 | 
						|
	return func(c echo.Context) (string, error) {
 | 
						|
		return c.Request().Header.Get(header), nil
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
 | 
						|
// provided form parameter.
 | 
						|
func csrfTokenFromForm(param string) csrfTokenExtractor {
 | 
						|
	return func(c echo.Context) (string, error) {
 | 
						|
		token := c.FormValue(param)
 | 
						|
		if token == "" {
 | 
						|
			return "", errors.New("missing csrf token in the form parameter")
 | 
						|
		}
 | 
						|
		return token, nil
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
 | 
						|
// provided query parameter.
 | 
						|
func csrfTokenFromQuery(param string) csrfTokenExtractor {
 | 
						|
	return func(c echo.Context) (string, error) {
 | 
						|
		token := c.QueryParam(param)
 | 
						|
		if token == "" {
 | 
						|
			return "", errors.New("missing csrf token in the query string")
 | 
						|
		}
 | 
						|
		return token, nil
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// validateCSRFToken reports whether clientToken, as supplied by the client,
// matches the expected token.
//
// An empty token on either side never validates: ConstantTimeCompare treats
// two empty inputs as equal, which would let a request carrying no token
// match an empty stored token.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	// Compare without short-circuiting on the first differing byte.
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
 |