proxy headers (#67)

This commit is contained in:
Paramtamtam
2022-02-23 11:09:54 +05:00
committed by GitHub
parent 06aff4ecb3
commit e857c0309b
13 changed files with 137 additions and 15 deletions

View File

@ -107,11 +107,20 @@ func run(parentCtx context.Context, log *zap.Logger, f flags, cfg *config.Config
}
}
var proxyHTTPHeaders = f.HeadersToProxy()
// create HTTP server
server := appHttp.NewServer(log)
// register server routes, middlewares, etc.
if err := server.Register(cfg, picker, f.defaultErrorPage, f.defaultHTTPCode, f.showDetails); err != nil {
if err := server.Register(
cfg,
picker,
f.defaultErrorPage,
f.defaultHTTPCode,
f.showDetails,
proxyHTTPHeaders,
); err != nil {
return err
}
@ -126,6 +135,7 @@ func run(parentCtx context.Context, log *zap.Logger, f flags, cfg *config.Config
zap.Uint16("port", f.listen.port),
zap.String("default error page", f.defaultErrorPage),
zap.Uint16("default HTTP response code", f.defaultHTTPCode),
zap.Strings("proxy headers", proxyHTTPHeaders),
zap.Bool("show request details", f.showDetails),
)

View File

@ -3,6 +3,7 @@ package serve
import (
"fmt"
"net"
"sort"
"strconv"
"strings"
@ -21,6 +22,45 @@ type flags struct {
defaultErrorPage string
defaultHTTPCode uint16
showDetails bool
proxyHTTPHeaders string // comma-separated
}
// HeadersToProxy converts a comma-separated string with headers list into strings slice (with a sorting and without
// duplicates).
func (f *flags) HeadersToProxy() []string {
var raw = strings.Split(f.proxyHTTPHeaders, ",")
if len(raw) == 0 {
return []string{}
} else if len(raw) == 1 {
if h := strings.TrimSpace(raw[0]); h != "" {
return []string{h}
} else {
return []string{}
}
}
var m = make(map[string]struct{}, len(raw))
// make unique and ignore empty strings
for _, h := range raw {
if h = strings.TrimSpace(h); h != "" {
if _, ok := m[h]; !ok {
m[h] = struct{}{}
}
}
}
// convert map into slice
var headers = make([]string, 0, len(m))
for h := range m {
headers = append(headers, h)
}
// make sort
sort.Strings(headers)
return headers
}
const (
@ -30,6 +70,7 @@ const (
defaultErrorPageFlagName = "default-error-page"
defaultHTTPCodeFlagName = "default-http-code"
showDetailsFlagName = "show-details"
proxyHTTPHeadersFlagName = "proxy-headers"
)
const (
@ -84,6 +125,12 @@ func (f *flags) init(flagSet *pflag.FlagSet) {
false,
fmt.Sprintf("show request details in response [$%s]", env.ShowDetails),
)
flagSet.StringVarP(
&f.proxyHTTPHeaders,
proxyHTTPHeadersFlagName, "",
"",
fmt.Sprintf("proxy HTTP request headers list (comma-separated) [$%s]", env.ProxyHTTPHeaders),
)
}
func (f *flags) overrideUsingEnv(flagSet *pflag.FlagSet) (lastErr error) { //nolint:gocognit,gocyclo
@ -130,6 +177,11 @@ func (f *flags) overrideUsingEnv(flagSet *pflag.FlagSet) (lastErr error) { //nol
f.showDetails = b
}
}
case proxyHTTPHeadersFlagName:
if envVar, exists := env.ProxyHTTPHeaders.Lookup(); exists {
f.proxyHTTPHeaders = strings.TrimSpace(envVar)
}
}
}
})
@ -146,5 +198,9 @@ func (f *flags) validate() error {
return fmt.Errorf("wrong default HTTP response code [%d]", f.defaultHTTPCode)
}
if strings.ContainsRune(f.proxyHTTPHeaders, ' ') {
return fmt.Errorf("whitespaces in the HTTP headers for proxying [%s] are not allowed", f.proxyHTTPHeaders)
}
return nil
}