mirror of
https://github.com/tarampampam/error-pages.git
synced 2024-08-30 18:22:40 +00:00
feat: fasthttp is back
This commit is contained in:
parent
afaef54ddf
commit
df2163e41e
4
go.mod
4
go.mod
@ -10,12 +10,16 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
|
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/klauspost/compress v1.17.9 // indirect
|
||||||
github.com/kr/pretty v0.3.1 // indirect
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
|
github.com/valyala/fasthttp v1.55.0 // indirect
|
||||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
|
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
8
go.sum
8
go.sum
@ -1,8 +1,12 @@
|
|||||||
|
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
|
||||||
|
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
|
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||||
|
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
@ -26,6 +30,10 @@ github.com/urfave/cli-docs/v3 v3.0.0-alpha5 h1:H1oWnR2/GN0dNm2PVylws+GxSOD6YOwW/
|
|||||||
github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4=
|
github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4=
|
||||||
github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo=
|
github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo=
|
||||||
github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
|
github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
|
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
|
||||||
|
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
|
||||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
|
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
|
||||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
|
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
|
||||||
go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8=
|
go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8=
|
||||||
|
@ -55,7 +55,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit
|
|||||||
Aliases: []string{"perf", "test"},
|
Aliases: []string{"perf", "test"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
Usage: "Simple performance (load) test for the HTTP server",
|
Usage: "Simple performance (load) test for the HTTP server",
|
||||||
Action: func(ctx context.Context, c *cli.Command) error {
|
Action: func(ctx context.Context, c *cli.Command) error { // TODO: use fasthttp.Client
|
||||||
var (
|
var (
|
||||||
perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name))
|
perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name))
|
||||||
startedAt = time.Now()
|
startedAt = time.Now()
|
||||||
|
@ -290,7 +290,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit,gocy
|
|||||||
|
|
||||||
// Run current command.
|
// Run current command.
|
||||||
func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen
|
func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen
|
||||||
var srv = appHttp.NewServer(ctx, log)
|
var srv = appHttp.NewServer(log)
|
||||||
|
|
||||||
if err := srv.Register(cfg); err != nil {
|
if err := srv.Register(cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package error_page
|
package error_page
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// extractCodeFromURL extracts the error code from the given URL.
|
// extractCodeFromURL extracts the error code from the given URL.
|
||||||
@ -37,11 +38,15 @@ func extractCodeFromURL(url string) (uint16, bool) {
|
|||||||
func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn
|
func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn
|
||||||
|
|
||||||
// extractCodeFromHeaders extracts the error code from the given headers.
|
// extractCodeFromHeaders extracts the error code from the given headers.
|
||||||
func extractCodeFromHeaders(headers http.Header) (uint16, bool) {
|
func extractCodeFromHeaders(headers *fasthttp.RequestHeader) (uint16, bool) {
|
||||||
|
if headers == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
||||||
// HTTP status code returned by the request
|
// HTTP status code returned by the request
|
||||||
if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 {
|
if value := headers.Peek("X-Code"); len(value) > 0 && len(value) <= 3 {
|
||||||
if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 {
|
if code, err := strconv.ParseUint(string(value), 10, 16); err == nil && code > 0 && code < 999 {
|
||||||
return uint16(code), true
|
return uint16(code), true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -50,7 +55,7 @@ func extractCodeFromHeaders(headers http.Header) (uint16, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HeadersContainCode checks if the given headers contain an error code.
|
// HeadersContainCode checks if the given headers contain an error code.
|
||||||
func HeadersContainCode(headers http.Header) (ok bool) {
|
func HeadersContainCode(headers *fasthttp.RequestHeader) (ok bool) {
|
||||||
_, ok = extractCodeFromHeaders(headers)
|
_, ok = extractCodeFromHeaders(headers)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package error_page_test
|
package error_page_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
||||||
)
|
)
|
||||||
@ -36,18 +36,26 @@ func TestURLContainsCode(t *testing.T) {
|
|||||||
func TestHeadersContainCode(t *testing.T) {
|
func TestHeadersContainCode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
var mkHeaders = func(key, value string) *fasthttp.RequestHeader {
|
||||||
|
var out = new(fasthttp.RequestHeader)
|
||||||
|
|
||||||
|
out.Set(key, value)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
for name, _tt := range map[string]struct {
|
for name, _tt := range map[string]struct {
|
||||||
giveHeaders http.Header
|
giveHeaders *fasthttp.RequestHeader
|
||||||
wantOk bool
|
wantOk bool
|
||||||
}{
|
}{
|
||||||
"with code": {giveHeaders: http.Header{"X-Code": {"404"}}, wantOk: true},
|
"with code": {giveHeaders: mkHeaders("X-Code", "404"), wantOk: true},
|
||||||
|
|
||||||
"empty": {giveHeaders: nil},
|
"empty": {giveHeaders: nil},
|
||||||
"no code": {giveHeaders: http.Header{"X-Code": {""}}},
|
"no code": {giveHeaders: mkHeaders("X-Code", "")},
|
||||||
"wrong": {giveHeaders: http.Header{"X-Code": {"foo"}}},
|
"wrong": {giveHeaders: mkHeaders("X-Code", "foo")},
|
||||||
"too big": {giveHeaders: http.Header{"X-Code": {"1000"}}},
|
"too big": {giveHeaders: mkHeaders("X-Code", "1000")},
|
||||||
"too small": {giveHeaders: http.Header{"X-Code": {"0"}}},
|
"too small": {giveHeaders: mkHeaders("X-Code", "0")},
|
||||||
"negative": {giveHeaders: http.Header{"X-Code": {"-1"}}},
|
"negative": {giveHeaders: mkHeaders("X-Code", "-1")},
|
||||||
} {
|
} {
|
||||||
tt := _tt
|
tt := _tt
|
||||||
|
|
||||||
|
@ -2,10 +2,11 @@ package error_page
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type preferredFormat = byte
|
type preferredFormat = byte
|
||||||
@ -21,10 +22,10 @@ const (
|
|||||||
// detectPreferredFormatForClient detects the preferred format for the client based on the headers.
|
// detectPreferredFormatForClient detects the preferred format for the client based on the headers.
|
||||||
// It supports the following headers: Content-Type, Accept, X-Format.
|
// It supports the following headers: Content-Type, Accept, X-Format.
|
||||||
// If the headers are not set or the format is not recognized, it returns unknownFormat.
|
// If the headers are not set or the format is not recognized, it returns unknownFormat.
|
||||||
func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nolint:funlen,gocognit
|
func detectPreferredFormatForClient(headers *fasthttp.RequestHeader) preferredFormat { //nolint:funlen,gocognit
|
||||||
var contentType, accept string
|
var contentType, accept string
|
||||||
|
|
||||||
if contentTypeHeader := strings.TrimSpace(headers.Get("Content-Type")); contentTypeHeader != "" { //nolint:nestif
|
if contentTypeHeader := strings.TrimSpace(string(headers.Peek("Content-Type"))); contentTypeHeader != "" { //nolint:nestif,lll
|
||||||
// https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type
|
// https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type
|
||||||
// text/html; charset=utf-8
|
// text/html; charset=utf-8
|
||||||
// multipart/form-data; boundary=something
|
// multipart/form-data; boundary=something
|
||||||
@ -38,11 +39,11 @@ func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nol
|
|||||||
// take the whole value
|
// take the whole value
|
||||||
contentType = contentTypeHeader
|
contentType = contentTypeHeader
|
||||||
}
|
}
|
||||||
} else if xFormatHeader := strings.TrimSpace(headers.Get("X-Format")); xFormatHeader != "" {
|
} else if xFormatHeader := strings.TrimSpace(string(headers.Peek("X-Format"))); xFormatHeader != "" {
|
||||||
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
||||||
// Value of the `Accept` header sent by the client
|
// Value of the `Accept` header sent by the client
|
||||||
accept = xFormatHeader
|
accept = xFormatHeader
|
||||||
} else if acceptHeader := strings.TrimSpace(headers.Get("Accept")); acceptHeader != "" {
|
} else if acceptHeader := strings.TrimSpace(string(headers.Peek("Accept"))); acceptHeader != "" {
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
|
||||||
// text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8
|
// text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8
|
||||||
// text/html
|
// text/html
|
||||||
|
@ -1,80 +1,80 @@
|
|||||||
package error_page
|
package error_page
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_detectPreferredFormatForClient(t *testing.T) {
|
func Test_detectPreferredFormatForClient(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
for name, _tt := range map[string]struct {
|
for name, _tt := range map[string]struct {
|
||||||
giveHeaders http.Header
|
giveHeaders map[string][]string
|
||||||
wantFormat preferredFormat
|
wantFormat preferredFormat
|
||||||
}{
|
}{
|
||||||
"content type json": {
|
"content type json": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"application/jSoN"}},
|
giveHeaders: map[string][]string{"Content-Type": {"application/jSoN"}},
|
||||||
wantFormat: jsonFormat,
|
wantFormat: jsonFormat,
|
||||||
},
|
},
|
||||||
"content type xml": {
|
"content type xml": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"application/xml; charset=UTF-8"}},
|
giveHeaders: map[string][]string{"Content-Type": {"application/xml; charset=UTF-8"}},
|
||||||
wantFormat: xmlFormat,
|
wantFormat: xmlFormat,
|
||||||
},
|
},
|
||||||
"content type html": {
|
"content type html": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"text/hTmL; charset=utf-8"}},
|
giveHeaders: map[string][]string{"Content-Type": {"text/hTmL; charset=utf-8"}},
|
||||||
wantFormat: htmlFormat,
|
wantFormat: htmlFormat,
|
||||||
},
|
},
|
||||||
"content type plain": {
|
"content type plain": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"text/plaIN"}},
|
giveHeaders: map[string][]string{"Content-Type": {"text/plaIN"}},
|
||||||
wantFormat: plainTextFormat,
|
wantFormat: plainTextFormat,
|
||||||
},
|
},
|
||||||
|
|
||||||
"accept json": {
|
"accept json": {
|
||||||
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
|
||||||
wantFormat: jsonFormat,
|
wantFormat: jsonFormat,
|
||||||
},
|
},
|
||||||
"accept xml, depends on weight": {
|
"accept xml, depends on weight": {
|
||||||
giveHeaders: http.Header{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}},
|
||||||
wantFormat: xmlFormat,
|
wantFormat: xmlFormat,
|
||||||
},
|
},
|
||||||
"accept json, depends on weight": {
|
"accept json, depends on weight": {
|
||||||
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
|
||||||
wantFormat: jsonFormat,
|
wantFormat: jsonFormat,
|
||||||
},
|
},
|
||||||
"accept xml": {
|
"accept xml": {
|
||||||
giveHeaders: http.Header{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
|
||||||
wantFormat: xmlFormat,
|
wantFormat: xmlFormat,
|
||||||
},
|
},
|
||||||
"accept html": {
|
"accept html": {
|
||||||
giveHeaders: http.Header{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}},
|
||||||
wantFormat: htmlFormat,
|
wantFormat: htmlFormat,
|
||||||
},
|
},
|
||||||
"accept plain": {
|
"accept plain": {
|
||||||
giveHeaders: http.Header{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}},
|
||||||
wantFormat: plainTextFormat,
|
wantFormat: plainTextFormat,
|
||||||
},
|
},
|
||||||
"accept json, weighted values only": {
|
"accept json, weighted values only": {
|
||||||
giveHeaders: http.Header{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}},
|
||||||
wantFormat: jsonFormat,
|
wantFormat: jsonFormat,
|
||||||
},
|
},
|
||||||
|
|
||||||
"x-format json, depends on weight": {
|
"x-format json, depends on weight": {
|
||||||
giveHeaders: http.Header{"X-Format": {"application/jsoN,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"X-Format": {"application/jsoN,*/*;q=0.8"}},
|
||||||
wantFormat: jsonFormat,
|
wantFormat: jsonFormat,
|
||||||
},
|
},
|
||||||
"x-format xml": {
|
"x-format xml": {
|
||||||
giveHeaders: http.Header{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
|
giveHeaders: map[string][]string{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
|
||||||
wantFormat: xmlFormat,
|
wantFormat: xmlFormat,
|
||||||
},
|
},
|
||||||
|
|
||||||
"content type has priority over accept": {
|
"content type has priority over accept": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}},
|
giveHeaders: map[string][]string{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}},
|
||||||
wantFormat: plainTextFormat,
|
wantFormat: plainTextFormat,
|
||||||
},
|
},
|
||||||
"accept has priority over x-format": {
|
"accept has priority over x-format": {
|
||||||
giveHeaders: http.Header{"Accept": {"application/xml"}, "X-Format": {"text/plain"}},
|
giveHeaders: map[string][]string{"Accept": {"application/xml"}, "X-Format": {"text/plain"}},
|
||||||
wantFormat: plainTextFormat,
|
wantFormat: plainTextFormat,
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -82,25 +82,33 @@ func Test_detectPreferredFormatForClient(t *testing.T) {
|
|||||||
giveHeaders: nil,
|
giveHeaders: nil,
|
||||||
},
|
},
|
||||||
"empty content type": {
|
"empty content type": {
|
||||||
giveHeaders: http.Header{"Content-Type": {" "}},
|
giveHeaders: map[string][]string{"Content-Type": {" "}},
|
||||||
},
|
},
|
||||||
"wrong content type": {
|
"wrong content type": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"multipart/form-data; boundary=something"}},
|
giveHeaders: map[string][]string{"Content-Type": {"multipart/form-data; boundary=something"}},
|
||||||
},
|
},
|
||||||
"wrong accept": {
|
"wrong accept": {
|
||||||
giveHeaders: http.Header{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}},
|
giveHeaders: map[string][]string{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}},
|
||||||
},
|
},
|
||||||
"none on invalid input": {
|
"none on invalid input": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}},
|
giveHeaders: map[string][]string{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}},
|
||||||
},
|
},
|
||||||
"completely unknown": {
|
"completely unknown": {
|
||||||
giveHeaders: http.Header{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}},
|
giveHeaders: map[string][]string{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
tt := _tt
|
tt := _tt
|
||||||
|
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(tt.giveHeaders))
|
var headers = new(fasthttp.RequestHeader)
|
||||||
|
|
||||||
|
for key, values := range tt.giveHeaders {
|
||||||
|
for _, value := range values {
|
||||||
|
headers.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(headers))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,21 +7,24 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/config"
|
"gh.tarampamp.am/error-pages/internal/config"
|
||||||
"gh.tarampamp.am/error-pages/internal/logger"
|
"gh.tarampamp.am/error-pages/internal/logger"
|
||||||
"gh.tarampamp.am/error-pages/internal/template"
|
"gh.tarampamp.am/error-pages/internal/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a new handler that returns an error page with the specified status code and format.
|
// New creates a new handler that returns an error page with the specified status code and format.
|
||||||
func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,gocognit,gocyclo
|
func New(cfg *config.Config, log *logger.Logger) fasthttp.RequestHandler { //nolint:funlen,gocognit,gocyclo
|
||||||
const contentTypeHeader = "Content-Type"
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
var (
|
||||||
|
reqHeaders = &ctx.Request.Header
|
||||||
|
code uint16
|
||||||
|
)
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
if fromUrl, okUrl := extractCodeFromURL(string(ctx.RequestURI())); okUrl {
|
||||||
var code uint16
|
|
||||||
|
|
||||||
if fromUrl, okUrl := extractCodeFromURL(r.URL.Path); okUrl {
|
|
||||||
code = fromUrl
|
code = fromUrl
|
||||||
} else if fromHeader, okHeaders := extractCodeFromHeaders(r.Header); okHeaders {
|
} else if fromHeader, okHeaders := extractCodeFromHeaders(reqHeaders); okHeaders {
|
||||||
code = fromHeader
|
code = fromHeader
|
||||||
} else {
|
} else {
|
||||||
code = cfg.DefaultCodeToRender
|
code = cfg.DefaultCodeToRender
|
||||||
@ -35,23 +38,23 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
httpCode = http.StatusOK
|
httpCode = http.StatusOK
|
||||||
}
|
}
|
||||||
|
|
||||||
var format = detectPreferredFormatForClient(r.Header)
|
var format = detectPreferredFormatForClient(reqHeaders)
|
||||||
|
|
||||||
{ // deal with the headers
|
{ // deal with the headers
|
||||||
switch format {
|
switch format {
|
||||||
case jsonFormat:
|
case jsonFormat:
|
||||||
w.Header().Set(contentTypeHeader, "application/json; charset=utf-8")
|
ctx.SetContentType("application/json; charset=utf-8")
|
||||||
case xmlFormat:
|
case xmlFormat:
|
||||||
w.Header().Set(contentTypeHeader, "application/xml; charset=utf-8")
|
ctx.SetContentType("application/xml; charset=utf-8")
|
||||||
case htmlFormat:
|
case htmlFormat:
|
||||||
w.Header().Set(contentTypeHeader, "text/html; charset=utf-8")
|
ctx.SetContentType("text/html; charset=utf-8")
|
||||||
default:
|
default:
|
||||||
w.Header().Set(contentTypeHeader, "text/plain; charset=utf-8") // plainTextFormat as default
|
ctx.SetContentType("text/plain; charset=utf-8") // plainTextFormat as default
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag
|
// https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag
|
||||||
// disallow indexing of the error pages
|
// disallow indexing of the error pages
|
||||||
w.Header().Set("X-Robots-Tag", "noindex")
|
ctx.Response.Header.Set("X-Robots-Tag", "noindex")
|
||||||
|
|
||||||
switch code {
|
switch code {
|
||||||
case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests,
|
case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests,
|
||||||
@ -59,18 +62,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
http.StatusGatewayTimeout:
|
http.StatusGatewayTimeout:
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
|
||||||
// tell the client (search crawler) to retry the request after 120 seconds
|
// tell the client (search crawler) to retry the request after 120 seconds
|
||||||
w.Header().Set("Retry-After", "120")
|
ctx.Response.Header.Set("Retry-After", "120")
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxy the headers from the incoming request to the error page response if they are defined in the config
|
// proxy the headers from the incoming request to the error page response if they are defined in the config
|
||||||
for _, proxyHeader := range cfg.ProxyHeaders {
|
for _, proxyHeader := range cfg.ProxyHeaders {
|
||||||
if value := r.Header.Get(proxyHeader); value != "" {
|
if value := reqHeaders.Peek(proxyHeader); len(value) > 0 {
|
||||||
w.Header().Set(proxyHeader, value)
|
ctx.Response.Header.SetBytesV(proxyHeader, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(httpCode)
|
ctx.SetStatusCode(httpCode)
|
||||||
|
|
||||||
// prepare the template properties for rendering
|
// prepare the template properties for rendering
|
||||||
var tplProps = template.Props{
|
var tplProps = template.Props{
|
||||||
@ -81,14 +84,14 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
|
|
||||||
//nolint:lll
|
//nolint:lll
|
||||||
if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
|
||||||
tplProps.OriginalURI = r.Header.Get("X-Original-URI") // (ingress-nginx) URI that caused the error
|
tplProps.OriginalURI = string(reqHeaders.Peek("X-Original-URI")) // (ingress-nginx) URI that caused the error
|
||||||
tplProps.Namespace = r.Header.Get("X-Namespace") // (ingress-nginx) namespace where the backend Service is located
|
tplProps.Namespace = string(reqHeaders.Peek("X-Namespace")) // (ingress-nginx) namespace where the backend Service is located
|
||||||
tplProps.IngressName = r.Header.Get("X-Ingress-Name") // (ingress-nginx) name of the Ingress where the backend is defined
|
tplProps.IngressName = string(reqHeaders.Peek("X-Ingress-Name")) // (ingress-nginx) name of the Ingress where the backend is defined
|
||||||
tplProps.ServiceName = r.Header.Get("X-Service-Name") // (ingress-nginx) name of the Service backing the backend
|
tplProps.ServiceName = string(reqHeaders.Peek("X-Service-Name")) // (ingress-nginx) name of the Service backing the backend
|
||||||
tplProps.ServicePort = r.Header.Get("X-Service-Port") // (ingress-nginx) port number of the Service backing the backend
|
tplProps.ServicePort = string(reqHeaders.Peek("X-Service-Port")) // (ingress-nginx) port number of the Service backing the backend
|
||||||
tplProps.RequestID = r.Header.Get("X-Request-Id") // (ingress-nginx) unique ID that identifies the request - same as for backend service
|
tplProps.RequestID = string(reqHeaders.Peek("X-Request-Id")) // (ingress-nginx) unique ID that identifies the request - same as for backend service
|
||||||
tplProps.ForwardedFor = r.Header.Get("X-Forwarded-For") // the value of the `X-Forwarded-For` header
|
tplProps.ForwardedFor = string(reqHeaders.Peek("X-Forwarded-For")) // the value of the `X-Forwarded-For` header
|
||||||
tplProps.Host = r.Host // the value of the `Host` header
|
tplProps.Host = string(reqHeaders.Peek("Host")) // the value of the `Host` header
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to find the code message and description in the config and if not - use the standard status text or fallback
|
// try to find the code message and description in the config and if not - use the standard status text or fallback
|
||||||
@ -105,18 +108,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
case format == jsonFormat && cfg.Formats.JSON != "":
|
case format == jsonFormat && cfg.Formats.JSON != "":
|
||||||
if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil {
|
if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil {
|
||||||
j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error()))
|
j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error()))
|
||||||
write(w, log, j)
|
write(ctx, log, j)
|
||||||
} else {
|
} else {
|
||||||
write(w, log, content)
|
write(ctx, log, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
case format == xmlFormat && cfg.Formats.XML != "":
|
case format == xmlFormat && cfg.Formats.XML != "":
|
||||||
if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil {
|
if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil {
|
||||||
write(w, log, fmt.Sprintf(
|
write(ctx, log, fmt.Sprintf(
|
||||||
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<error>Failed to render the XML template: %s</error>", err.Error(),
|
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<error>Failed to render the XML template: %s</error>", err.Error(),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
write(w, log, content)
|
write(ctx, log, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
case format == htmlFormat:
|
case format == htmlFormat:
|
||||||
@ -125,16 +128,16 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
if tpl, found := cfg.Templates.Get(templateName); found {
|
if tpl, found := cfg.Templates.Get(templateName); found {
|
||||||
if content, err := template.Render(tpl, tplProps); err != nil {
|
if content, err := template.Render(tpl, tplProps); err != nil {
|
||||||
// TODO: add GZIP compression for the HTML content support
|
// TODO: add GZIP compression for the HTML content support
|
||||||
write(w, log, fmt.Sprintf(
|
write(ctx, log, fmt.Sprintf(
|
||||||
"<!DOCTYPE html>\n<html><body>Failed to render the HTML template %s: %s</body></html>",
|
"<!DOCTYPE html>\n<html><body>Failed to render the HTML template %s: %s</body></html>",
|
||||||
templateName,
|
templateName,
|
||||||
err.Error(),
|
err.Error(),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
write(w, log, content)
|
write(ctx, log, content)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
write(w, log, fmt.Sprintf(
|
write(ctx, log, fmt.Sprintf(
|
||||||
"<!DOCTYPE html>\n<html><body>Template %s not found and cannot be used</body></html>", templateName,
|
"<!DOCTYPE html>\n<html><body>Template %s not found and cannot be used</body></html>", templateName,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@ -142,18 +145,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
|
|||||||
default: // plainTextFormat as default
|
default: // plainTextFormat as default
|
||||||
if cfg.Formats.PlainText != "" {
|
if cfg.Formats.PlainText != "" {
|
||||||
if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil {
|
if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil {
|
||||||
write(w, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error()))
|
write(ctx, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error()))
|
||||||
} else {
|
} else {
|
||||||
write(w, log, content)
|
write(ctx, log, content)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
write(w, log, `The requested content format is not supported.
|
write(ctx, log, `The requested content format is not supported.
|
||||||
Please create an issue on the project's GitHub page to request support for this format.
|
Please create an issue on the project's GitHub page to request support for this format.
|
||||||
|
|
||||||
Supported formats: JSON, XML, HTML, Plain Text`)
|
Supported formats: JSON, XML, HTML, Plain Text`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -204,7 +207,7 @@ func templateToUse(cfg *config.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write the content to the response writer and log the error if any.
|
// write the content to the response writer and log the error if any.
|
||||||
func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content T) {
|
func write[T string | []byte](ctx *fasthttp.RequestCtx, log *logger.Logger, content T) {
|
||||||
var data []byte
|
var data []byte
|
||||||
|
|
||||||
if s, ok := any(content).(string); ok {
|
if s, ok := any(content).(string); ok {
|
||||||
@ -213,7 +216,7 @@ func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content
|
|||||||
data = any(content).([]byte)
|
data = any(content).([]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := w.Write(data); err != nil && log != nil {
|
if _, err := ctx.Write(data); err != nil && log != nil {
|
||||||
log.Error("failed to write the response body",
|
log.Error("failed to write the response body",
|
||||||
logger.String("content", string(data)),
|
logger.String("content", string(data)),
|
||||||
logger.Error(err),
|
logger.Error(err),
|
||||||
|
@ -2,13 +2,14 @@ package error_page_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/config"
|
"gh.tarampamp.am/error-pages/internal/config"
|
||||||
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
||||||
|
"gh.tarampamp.am/error-pages/internal/http/httptest"
|
||||||
"gh.tarampamp.am/error-pages/internal/logger"
|
"gh.tarampamp.am/error-pages/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ func TestHandler(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
"common, plain text": {
|
"common, plain text": {
|
||||||
giveConfig: func() *config.Config { cfg := config.New(); return &cfg },
|
giveConfig: func() *config.Config { cfg := config.New(); return &cfg },
|
||||||
giveUrl: "/",
|
giveUrl: "http://testing/",
|
||||||
giveHeaders: map[string]string{"Content-Type": "text/plain"},
|
giveHeaders: map[string]string{"Content-Type": "text/plain"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -41,7 +42,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/",
|
giveUrl: "http://testing/",
|
||||||
giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"},
|
giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -60,7 +61,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/503.html",
|
giveUrl: "http://testing/503.html",
|
||||||
giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"},
|
giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusServiceUnavailable,
|
wantStatusCode: http.StatusServiceUnavailable,
|
||||||
@ -78,7 +79,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/500",
|
giveUrl: "http://testing/500",
|
||||||
giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"},
|
giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -96,7 +97,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/503",
|
giveUrl: "http://example.com/503",
|
||||||
giveHeaders: map[string]string{
|
giveHeaders: map[string]string{
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"X-Original-URI": "/foo/bar",
|
"X-Original-URI": "/foo/bar",
|
||||||
@ -106,7 +107,6 @@ func TestHandler(t *testing.T) {
|
|||||||
"X-Service-Port": "666",
|
"X-Service-Port": "666",
|
||||||
"X-Request-ID": "req-id-777",
|
"X-Request-ID": "req-id-777",
|
||||||
"X-Forwarded-For": "123.123.123.123:12312",
|
"X-Forwarded-For": "123.123.123.123:12312",
|
||||||
"Host": "example.com",
|
|
||||||
},
|
},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -133,7 +133,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/100",
|
giveUrl: "http://testing/100",
|
||||||
giveHeaders: map[string]string{"Accept": "application/json"},
|
giveHeaders: map[string]string{"Accept": "application/json"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -148,7 +148,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
return &cfg
|
return &cfg
|
||||||
},
|
},
|
||||||
giveUrl: "/1",
|
giveUrl: "http://testing/1",
|
||||||
giveHeaders: map[string]string{"Accept": "application/json"},
|
giveHeaders: map[string]string{"Accept": "application/json"},
|
||||||
|
|
||||||
wantStatusCode: http.StatusOK,
|
wantStatusCode: http.StatusOK,
|
||||||
@ -159,32 +159,31 @@ func TestHandler(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var (
|
var handler = error_page.New(tt.giveConfig(), logger.NewNop())
|
||||||
req = httptest.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
|
|
||||||
handler = error_page.New(tt.giveConfig(), logger.NewNop())
|
req, reqErr := http.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
|
||||||
rr = httptest.NewRecorder()
|
require.NoError(t, reqErr)
|
||||||
)
|
|
||||||
|
|
||||||
for k, v := range tt.giveHeaders {
|
for k, v := range tt.giveHeaders {
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
|
||||||
|
assert.Equal(t, tt.wantStatusCode, status)
|
||||||
assert.Equal(t, rr.Code, tt.wantStatusCode)
|
|
||||||
|
|
||||||
for hName, hWant := range tt.wantHeaders {
|
for hName, hWant := range tt.wantHeaders {
|
||||||
for hGot := range rr.Header() {
|
for hGot := range headers {
|
||||||
if hGot == hName {
|
if hGot == hName {
|
||||||
assert.Contains(t, hWant, rr.Header().Get(hGot))
|
assert.Contains(t, hWant, headers.Get(hGot))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, wantBodyInclude := range tt.wantBodyIncludes {
|
for _, wantBodyInclude := range tt.wantBodyIncludes {
|
||||||
assert.Contains(t, rr.Body.String(), wantBodyInclude)
|
assert.Contains(t, body, wantBodyInclude)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,19 +206,17 @@ func TestRotationModeOnEachRequest(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for range 300 {
|
for range 300 {
|
||||||
var (
|
req, reqErr := http.NewRequest(http.MethodGet, "http://testing/", http.NoBody)
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
require.NoError(t, reqErr)
|
||||||
rr = httptest.NewRecorder()
|
|
||||||
)
|
|
||||||
|
|
||||||
req.Header.Set("Accept", "text/html")
|
req.Header.Set("Accept", "text/html")
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
|
||||||
|
if lastResponseBody != body {
|
||||||
if lastResponseBody != rr.Body.String() {
|
|
||||||
changedTimes++
|
changedTimes++
|
||||||
lastResponseBody = rr.Body.String()
|
lastResponseBody = body
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, changedTimes > 30, "the template should be changed at least 30 times")
|
assert.True(t, changedTimes > 30, "the template should be changed at least 30 times")
|
||||||
|
@ -2,24 +2,29 @@ package live
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a new handler that returns "OK" for GET and HEAD requests.
|
// New creates a new handler that returns "OK" for GET and HEAD requests.
|
||||||
func New() http.Handler {
|
func New() fasthttp.RequestHandler {
|
||||||
var body = []byte("OK\n")
|
var (
|
||||||
|
body = []byte("OK\n")
|
||||||
|
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
|
||||||
|
)
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
switch r.Method {
|
switch string(ctx.Method()) {
|
||||||
case http.MethodGet:
|
case fasthttp.MethodGet:
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
ctx.SetContentType("text/plain; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusOK)
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
_, _ = w.Write(body)
|
_, _ = ctx.Write(body)
|
||||||
|
|
||||||
case http.MethodHead:
|
case fasthttp.MethodHead:
|
||||||
w.WriteHeader(http.StatusOK)
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -2,43 +2,37 @@ package live_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/http/handlers/live"
|
"gh.tarampamp.am/error-pages/internal/http/handlers/live"
|
||||||
|
"gh.tarampamp.am/error-pages/internal/http/httptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var handler = live.New()
|
|
||||||
|
|
||||||
t.Run("get", func(t *testing.T) {
|
|
||||||
var (
|
var (
|
||||||
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
|
handler = live.New()
|
||||||
rr = httptest.NewRecorder()
|
url = "http://testing"
|
||||||
|
body = http.NoBody
|
||||||
)
|
)
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
t.Run("get", func(t *testing.T) {
|
||||||
|
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
|
assert.Equal(t, http.StatusOK, status)
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
|
||||||
assert.Equal(t, "OK\n", rr.Body.String())
|
assert.Equal(t, "OK\n", body)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("head", func(t *testing.T) {
|
t.Run("head", func(t *testing.T) {
|
||||||
var (
|
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusOK, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Empty(t, headers.Get("Content-Type"))
|
||||||
)
|
assert.Empty(t, body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
|
||||||
assert.Empty(t, rr.Header().Get("Content-Type"))
|
|
||||||
assert.Empty(t, rr.Body.Bytes())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("method not allowed", func(t *testing.T) {
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
http.MethodPost,
|
http.MethodPost,
|
||||||
http.MethodPut,
|
http.MethodPut,
|
||||||
} {
|
} {
|
||||||
var (
|
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(method, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
|
||||||
)
|
assert.Equal(t, "Method Not Allowed\n", body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
|
|
||||||
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
|
|
||||||
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,25 +3,29 @@ package static
|
|||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed favicon.ico
|
//go:embed favicon.ico
|
||||||
var Favicon []byte
|
var Favicon []byte
|
||||||
|
|
||||||
// New creates a new handler that returns the provided content for GET and HEAD requests.
|
// New creates a new handler that returns the provided content for GET and HEAD requests.
|
||||||
func New(content []byte) http.Handler {
|
func New(content []byte) fasthttp.RequestHandler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
|
||||||
switch r.Method {
|
|
||||||
case http.MethodGet:
|
|
||||||
w.Header().Set("Content-Type", http.DetectContentType(content))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write(content)
|
|
||||||
|
|
||||||
case http.MethodHead:
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
w.WriteHeader(http.StatusOK)
|
switch string(ctx.Method()) {
|
||||||
|
case fasthttp.MethodGet:
|
||||||
|
ctx.SetContentType(http.DetectContentType(content))
|
||||||
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
_, _ = ctx.Write(content)
|
||||||
|
|
||||||
|
case fasthttp.MethodHead:
|
||||||
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -2,43 +2,37 @@ package static_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/http/handlers/static"
|
"gh.tarampamp.am/error-pages/internal/http/handlers/static"
|
||||||
|
"gh.tarampamp.am/error-pages/internal/http/httptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var handler = static.New([]byte{1, 2, 3})
|
|
||||||
|
|
||||||
t.Run("get", func(t *testing.T) {
|
|
||||||
var (
|
var (
|
||||||
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
|
handler = static.New([]byte{1, 2, 3})
|
||||||
rr = httptest.NewRecorder()
|
url = "http://testing"
|
||||||
|
body = http.NoBody
|
||||||
)
|
)
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
t.Run("get", func(t *testing.T) {
|
||||||
|
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "application/octet-stream")
|
assert.Equal(t, http.StatusOK, status)
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
assert.Equal(t, "application/octet-stream", headers.Get("Content-Type"))
|
||||||
assert.Equal(t, rr.Body.Bytes(), []byte{1, 2, 3})
|
assert.Equal(t, []byte{1, 2, 3}, []byte(body))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("head", func(t *testing.T) {
|
t.Run("head", func(t *testing.T) {
|
||||||
var (
|
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusOK, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Empty(t, headers.Get("Content-Type"))
|
||||||
)
|
assert.Empty(t, body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
|
||||||
assert.Empty(t, rr.Header().Get("Content-Type"))
|
|
||||||
assert.Empty(t, rr.Body.Bytes())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("method not allowed", func(t *testing.T) {
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
http.MethodPost,
|
http.MethodPost,
|
||||||
http.MethodPut,
|
http.MethodPut,
|
||||||
} {
|
} {
|
||||||
var (
|
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(method, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
|
||||||
)
|
assert.Equal(t, "Method Not Allowed\n", body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
|
|
||||||
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
|
|
||||||
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -65,16 +54,15 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
func TestServeHTTP_Favicon(t *testing.T) {
|
func TestServeHTTP_Favicon(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var (
|
httptest.HandleFast(t,
|
||||||
handler = static.New(static.Favicon)
|
static.New(static.Favicon),
|
||||||
|
http.MethodGet,
|
||||||
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
|
"http://testing",
|
||||||
rr = httptest.NewRecorder()
|
http.NoBody,
|
||||||
|
func(status int, body string, headers http.Header) {
|
||||||
|
assert.Equal(t, http.StatusOK, status)
|
||||||
|
assert.Equal(t, "image/x-icon", headers.Get("Content-Type"))
|
||||||
|
assert.Equal(t, static.Favicon, []byte(body))
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "image/x-icon")
|
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
|
||||||
assert.Equal(t, rr.Body.Bytes(), static.Favicon)
|
|
||||||
}
|
}
|
||||||
|
@ -4,28 +4,32 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a handler that returns the version of the service in JSON format.
|
// New creates a handler that returns the version of the service in JSON format.
|
||||||
func New(ver string) http.Handler {
|
func New(ver string) fasthttp.RequestHandler {
|
||||||
var body, _ = json.Marshal(struct { //nolint:errchkjson
|
var body, _ = json.Marshal(struct { //nolint:errchkjson
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}{
|
}{
|
||||||
Version: strings.TrimSpace(ver),
|
Version: strings.TrimSpace(ver),
|
||||||
})
|
})
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
|
||||||
switch r.Method {
|
|
||||||
case http.MethodGet:
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write(body)
|
|
||||||
|
|
||||||
case http.MethodHead:
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
w.WriteHeader(http.StatusOK)
|
switch string(ctx.Method()) {
|
||||||
|
case fasthttp.MethodGet:
|
||||||
|
ctx.SetContentType("application/json; charset=utf-8")
|
||||||
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
_, _ = ctx.Write(body)
|
||||||
|
|
||||||
|
case fasthttp.MethodHead:
|
||||||
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -2,43 +2,37 @@ package version_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/http/handlers/version"
|
"gh.tarampamp.am/error-pages/internal/http/handlers/version"
|
||||||
|
"gh.tarampamp.am/error-pages/internal/http/httptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var handler = version.New("\t\n foo@bar ")
|
|
||||||
|
|
||||||
t.Run("get", func(t *testing.T) {
|
|
||||||
var (
|
var (
|
||||||
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
|
handler = version.New("\t\n foo@bar ")
|
||||||
rr = httptest.NewRecorder()
|
url = "http://testing"
|
||||||
|
body = http.NoBody
|
||||||
)
|
)
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
t.Run("get", func(t *testing.T) {
|
||||||
|
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "application/json; charset=utf-8")
|
assert.Equal(t, http.StatusOK, status)
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
assert.Equal(t, "application/json; charset=utf-8", headers.Get("Content-Type"))
|
||||||
assert.Equal(t, rr.Body.String(), `{"version":"foo@bar"}`)
|
assert.Equal(t, `{"version":"foo@bar"}`, body)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("head", func(t *testing.T) {
|
t.Run("head", func(t *testing.T) {
|
||||||
var (
|
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusOK, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Empty(t, headers.Get("Content-Type"))
|
||||||
)
|
assert.Empty(t, body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Code, http.StatusOK)
|
|
||||||
assert.Empty(t, rr.Header().Get("Content-Type"))
|
|
||||||
assert.Empty(t, rr.Body.Bytes())
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("method not allowed", func(t *testing.T) {
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
|
|||||||
http.MethodPost,
|
http.MethodPost,
|
||||||
http.MethodPut,
|
http.MethodPut,
|
||||||
} {
|
} {
|
||||||
var (
|
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
|
||||||
req = httptest.NewRequest(method, "http://testing", http.NoBody)
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
||||||
rr = httptest.NewRecorder()
|
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
|
||||||
)
|
assert.Equal(t, "Method Not Allowed\n", body)
|
||||||
|
})
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
|
|
||||||
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
|
|
||||||
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
69
internal/http/httptest/httptest.go
Normal file
69
internal/http/httptest/httptest.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// Package httptest provides utilities for (fast-)HTTP testing.
|
||||||
|
package httptest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"github.com/valyala/fasthttp/fasthttputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HandleFastRequest serves http request using provided fasthttp handler and HTTP request.
|
||||||
|
func HandleFastRequest(
|
||||||
|
t *testing.T,
|
||||||
|
handler fasthttp.RequestHandler,
|
||||||
|
req *http.Request,
|
||||||
|
check func(status int, body string, _ http.Header),
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// create in-memory listener
|
||||||
|
var ln = fasthttputil.NewInmemoryListener()
|
||||||
|
defer func() { require.NoError(t, ln.Close()) }()
|
||||||
|
|
||||||
|
// start fasthttp server
|
||||||
|
go func() { require.NoError(t, fasthttp.Serve(ln, handler)) }()
|
||||||
|
|
||||||
|
// send http request
|
||||||
|
resp, respErr := (&http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return ln.Dial() },
|
||||||
|
},
|
||||||
|
}).Do(req)
|
||||||
|
require.NoError(t, respErr)
|
||||||
|
|
||||||
|
// close response body after the test
|
||||||
|
defer func() { assert.NoError(t, resp.Body.Close()) }()
|
||||||
|
|
||||||
|
// read response body
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// check the response
|
||||||
|
check(resp.StatusCode, string(respBody), resp.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFast serves http request using provided fasthttp handler.
|
||||||
|
func HandleFast(
|
||||||
|
t *testing.T,
|
||||||
|
handler fasthttp.RequestHandler,
|
||||||
|
method string,
|
||||||
|
url string,
|
||||||
|
body io.Reader,
|
||||||
|
check func(status int, body string, _ http.Header),
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// create http request
|
||||||
|
req, reqErr := http.NewRequest(method, url, body)
|
||||||
|
require.NoError(t, reqErr)
|
||||||
|
|
||||||
|
// serve http request
|
||||||
|
HandleFastRequest(t, handler, req, check)
|
||||||
|
}
|
@ -1,20 +1,24 @@
|
|||||||
package logreq
|
package logreq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/logger"
|
"gh.tarampamp.am/error-pages/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a middleware for [http.ServeMux] that logs every incoming request.
|
// New creates a middleware that logs every incoming request.
|
||||||
//
|
//
|
||||||
// The skipper function should return true if the request should be skipped. It's ok to pass nil.
|
// The skipper function should return true if the request should be skipped. It's ok to pass nil.
|
||||||
func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler) http.Handler {
|
func New(
|
||||||
return func(next http.Handler) http.Handler {
|
log *logger.Logger,
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
skipper func(*fasthttp.RequestCtx) bool,
|
||||||
if skipper != nil && skipper(r) {
|
) func(fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
next.ServeHTTP(w, r)
|
return func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
if skipper != nil && skipper(ctx) {
|
||||||
|
next(ctx)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -23,27 +27,35 @@ func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
var fields = []logger.Attr{
|
var fields = []logger.Attr{
|
||||||
logger.String("useragent", r.UserAgent()),
|
logger.Int("status code", ctx.Response.StatusCode()),
|
||||||
logger.String("method", r.Method),
|
logger.String("useragent", string(ctx.UserAgent())),
|
||||||
logger.String("url", r.URL.String()),
|
logger.String("method", string(ctx.Method())),
|
||||||
logger.String("referer", r.Referer()),
|
logger.String("url", string(ctx.RequestURI())),
|
||||||
logger.String("content type", w.Header().Get("Content-Type")),
|
logger.String("referer", string(ctx.Referer())),
|
||||||
logger.String("remote addr", r.RemoteAddr),
|
logger.String("content type", string(ctx.Response.Header.ContentType())),
|
||||||
logger.String("method", r.Method),
|
logger.String("remote addr", ctx.RemoteAddr().String()),
|
||||||
logger.Duration("duration", time.Since(now).Round(time.Microsecond)),
|
logger.Duration("duration", time.Since(now).Round(time.Microsecond)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if log.Level() <= logger.DebugLevel {
|
if log.Level() <= logger.DebugLevel {
|
||||||
|
var (
|
||||||
|
reqHeaders = make(map[string]string)
|
||||||
|
respHeaders = make(map[string]string)
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.Request.Header.VisitAll(func(key, value []byte) { reqHeaders[string(key)] = string(value) })
|
||||||
|
ctx.Response.Header.VisitAll(func(key, value []byte) { respHeaders[string(key)] = string(value) })
|
||||||
|
|
||||||
fields = append(fields,
|
fields = append(fields,
|
||||||
logger.Any("request headers", r.Header.Clone()),
|
logger.Any("request headers", reqHeaders),
|
||||||
logger.Any("response headers", w.Header().Clone()),
|
logger.Any("response headers", respHeaders),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("HTTP request processed", fields...)
|
log.Info("HTTP request processed", fields...)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next(ctx)
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,12 @@ package logreq_test
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"gh.tarampamp.am/error-pages/internal/http/httptest"
|
||||||
"gh.tarampamp.am/error-pages/internal/http/middleware/logreq"
|
"gh.tarampamp.am/error-pages/internal/http/middleware/logreq"
|
||||||
"gh.tarampamp.am/error-pages/internal/logger"
|
"gh.tarampamp.am/error-pages/internal/logger"
|
||||||
)
|
)
|
||||||
@ -20,17 +21,18 @@ func TestNew(t *testing.T) {
|
|||||||
log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf)
|
log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf)
|
||||||
|
|
||||||
mw = logreq.New(log, nil)
|
mw = logreq.New(log, nil)
|
||||||
rr = httptest.NewRecorder()
|
req, _ = http.NewRequest(http.MethodPut, "http://testing/foo/bar", http.NoBody)
|
||||||
req = httptest.NewRequest(http.MethodPut, "/foo/bar", http.NoBody)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
req.Header.Set("User-Agent", "test")
|
req.Header.Set("User-Agent", "test")
|
||||||
req.Header.Set("Referer", "https://example.com")
|
req.Header.Set("Referer", "https://example.com")
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httptest.HandleFastRequest(t,
|
||||||
w.WriteHeader(http.StatusOK)
|
mw(func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(http.StatusOK) }),
|
||||||
})).ServeHTTP(rr, req)
|
req,
|
||||||
|
func(status int, body string, _ http.Header) { assert.Equal(t, http.StatusOK, status) },
|
||||||
|
)
|
||||||
|
|
||||||
var logRecord = buf.String()
|
var logRecord = buf.String()
|
||||||
|
|
||||||
|
@ -2,12 +2,15 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"gh.tarampamp.am/error-pages/internal/appmeta"
|
"gh.tarampamp.am/error-pages/internal/appmeta"
|
||||||
"gh.tarampamp.am/error-pages/internal/config"
|
"gh.tarampamp.am/error-pages/internal/config"
|
||||||
ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
|
||||||
@ -21,11 +24,11 @@ import (
|
|||||||
// Server is an HTTP server for serving error pages.
|
// Server is an HTTP server for serving error pages.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
server *http.Server
|
server *fasthttp.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new HTTP server.
|
// NewServer creates a new HTTP server.
|
||||||
func NewServer(baseCtx context.Context, log *logger.Logger) Server {
|
func NewServer(log *logger.Logger) Server {
|
||||||
const (
|
const (
|
||||||
readTimeout = 30 * time.Second
|
readTimeout = 30 * time.Second
|
||||||
writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout
|
writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout
|
||||||
@ -34,13 +37,14 @@ func NewServer(baseCtx context.Context, log *logger.Logger) Server {
|
|||||||
|
|
||||||
return Server{
|
return Server{
|
||||||
log: log,
|
log: log,
|
||||||
server: &http.Server{
|
server: &fasthttp.Server{
|
||||||
ReadTimeout: readTimeout,
|
ReadTimeout: readTimeout,
|
||||||
WriteTimeout: writeTimeout,
|
WriteTimeout: writeTimeout,
|
||||||
ReadHeaderTimeout: readTimeout,
|
ReadBufferSize: maxHeaderBytes,
|
||||||
MaxHeaderBytes: maxHeaderBytes,
|
DisablePreParseMultipartForm: true,
|
||||||
ErrorLog: logger.NewStdLog(log),
|
NoDefaultServerHeader: true,
|
||||||
BaseContext: func(net.Listener) context.Context { return baseCtx },
|
CloseOnShutdown: true,
|
||||||
|
Logger: logger.NewStdLog(log),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -52,60 +56,78 @@ func (s *Server) Register(cfg *config.Config) error {
|
|||||||
versionHandler = version.New(appmeta.Version())
|
versionHandler = version.New(appmeta.Version())
|
||||||
errorPagesHandler = ep.New(cfg, s.log)
|
errorPagesHandler = ep.New(cfg, s.log)
|
||||||
faviconHandler = static.New(static.Favicon)
|
faviconHandler = static.New(static.Favicon)
|
||||||
|
|
||||||
|
notFound = http.StatusText(http.StatusNotFound) + "\n"
|
||||||
|
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
s.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.server.Handler = func(ctx *fasthttp.RequestCtx) {
|
||||||
var url, method = r.URL.Path, r.Method
|
var url, method = string(ctx.RequestURI()), string(ctx.Method())
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
// live endpoints
|
// live endpoints
|
||||||
case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live":
|
case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live":
|
||||||
liveHandler.ServeHTTP(w, r)
|
liveHandler(ctx)
|
||||||
|
|
||||||
// version endpoint
|
// version endpoint
|
||||||
case url == "/version":
|
case url == "/version":
|
||||||
versionHandler.ServeHTTP(w, r)
|
versionHandler(ctx)
|
||||||
|
|
||||||
// favicon.ico endpoint
|
// favicon.ico endpoint
|
||||||
case url == "/favicon.ico":
|
case url == "/favicon.ico":
|
||||||
faviconHandler.ServeHTTP(w, r)
|
faviconHandler(ctx)
|
||||||
|
|
||||||
// error pages endpoints:
|
// error pages endpoints:
|
||||||
// - /
|
// - /
|
||||||
// - /{code}.html
|
// - /{code}.html
|
||||||
// - /{code}.htm
|
// - /{code}.htm
|
||||||
// - /{code}
|
// - /{code}
|
||||||
case method == http.MethodGet && (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(r.Header)):
|
case method == fasthttp.MethodGet &&
|
||||||
errorPagesHandler.ServeHTTP(w, r)
|
(url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(&ctx.Request.Header)):
|
||||||
|
errorPagesHandler(ctx)
|
||||||
|
|
||||||
// wrong requests handling
|
// wrong requests handling
|
||||||
default:
|
default:
|
||||||
switch {
|
switch {
|
||||||
case method == http.MethodHead:
|
case method == fasthttp.MethodHead:
|
||||||
w.WriteHeader(http.StatusNotFound)
|
ctx.Error(notAllowed, fasthttp.StatusNotFound)
|
||||||
case method == http.MethodGet:
|
case method == fasthttp.MethodGet:
|
||||||
http.NotFound(w, r)
|
ctx.Error(notFound, fasthttp.StatusNotFound)
|
||||||
default:
|
default:
|
||||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
ctx.Error(notAllowed, fasthttp.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
// apply middleware
|
// apply middleware
|
||||||
s.server.Handler = logreq.New(s.log, func(r *http.Request) bool {
|
s.server.Handler = logreq.New(s.log, func(ctx *fasthttp.RequestCtx) bool {
|
||||||
// skip logging healthcheck and .ico (favicon) requests
|
// skip logging healthcheck and .ico (favicon) requests
|
||||||
return strings.Contains(strings.ToLower(r.UserAgent()), "healthcheck") ||
|
return strings.Contains(strings.ToLower(string(ctx.UserAgent())), "healthcheck") ||
|
||||||
strings.HasSuffix(r.URL.Path, ".ico")
|
strings.HasSuffix(string(ctx.Path()), ".ico")
|
||||||
})(s.server.Handler)
|
})(s.server.Handler)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start server.
|
// Start server.
|
||||||
func (s *Server) Start(ip string, port uint16) error {
|
func (s *Server) Start(ip string, port uint16) (err error) {
|
||||||
s.server.Addr = ip + ":" + strconv.Itoa(int(port))
|
if net.ParseIP(ip) == nil {
|
||||||
|
return errors.New("invalid IP address")
|
||||||
|
}
|
||||||
|
|
||||||
return s.server.ListenAndServe()
|
var ln net.Listener
|
||||||
|
|
||||||
|
if strings.Count(ip, ":") >= 2 { //nolint:mnd // ipv6
|
||||||
|
if ln, err = net.Listen("tcp6", fmt.Sprintf("[%s]:%d", ip, port)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else { // ipv4
|
||||||
|
if ln, err = net.Listen("tcp4", fmt.Sprintf("%s:%d", ip, port)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.server.Serve(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop server gracefully.
|
// Stop server gracefully.
|
||||||
@ -113,5 +135,5 @@ func (s *Server) Stop(timeout time.Duration) error {
|
|||||||
var ctx, cancel = context.WithTimeout(context.Background(), timeout)
|
var ctx, cancel = context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return s.server.Shutdown(ctx)
|
return s.server.ShutdownWithContext(ctx)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package http_test
|
package http_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -21,7 +20,7 @@ import (
|
|||||||
// TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers.
|
// TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers.
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
srv = appHttp.NewServer(context.Background(), logger.NewNop())
|
srv = appHttp.NewServer(logger.NewNop())
|
||||||
cfg = config.New()
|
cfg = config.New()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -296,7 +295,7 @@ func TestRouting(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, http.StatusNotFound, status)
|
assert.Equal(t, http.StatusNotFound, status)
|
||||||
assert.Empty(t, body)
|
assert.Empty(t, body)
|
||||||
assert.Empty(t, headers.Get("Content-Type"))
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user