feat: fasthttp is back

This commit is contained in:
Paramtamtam 2024-07-01 23:41:26 +04:00
parent afaef54ddf
commit df2163e41e
No known key found for this signature in database
GPG Key ID: 366371698FAD0A2B
21 changed files with 432 additions and 315 deletions

4
go.mod
View File

@ -10,12 +10,16 @@ require (
) )
require ( require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.55.0 // indirect
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

8
go.sum
View File

@ -1,8 +1,12 @@
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@ -26,6 +30,10 @@ github.com/urfave/cli-docs/v3 v3.0.0-alpha5 h1:H1oWnR2/GN0dNm2PVylws+GxSOD6YOwW/
github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4= github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4=
github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo= github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo=
github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc= github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8=

View File

@ -55,7 +55,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit
Aliases: []string{"perf", "test"}, Aliases: []string{"perf", "test"},
Hidden: true, Hidden: true,
Usage: "Simple performance (load) test for the HTTP server", Usage: "Simple performance (load) test for the HTTP server",
Action: func(ctx context.Context, c *cli.Command) error { Action: func(ctx context.Context, c *cli.Command) error { // TODO: use fasthttp.Client
var ( var (
perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name)) perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name))
startedAt = time.Now() startedAt = time.Now()

View File

@ -290,7 +290,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit,gocy
// Run current command. // Run current command.
func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen
var srv = appHttp.NewServer(ctx, log) var srv = appHttp.NewServer(log)
if err := srv.Register(cfg); err != nil { if err := srv.Register(cfg); err != nil {
return err return err

View File

@ -1,10 +1,11 @@
package error_page package error_page
import ( import (
"net/http"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/valyala/fasthttp"
) )
// extractCodeFromURL extracts the error code from the given URL. // extractCodeFromURL extracts the error code from the given URL.
@ -37,11 +38,15 @@ func extractCodeFromURL(url string) (uint16, bool) {
func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn
// extractCodeFromHeaders extracts the error code from the given headers. // extractCodeFromHeaders extracts the error code from the given headers.
func extractCodeFromHeaders(headers http.Header) (uint16, bool) { func extractCodeFromHeaders(headers *fasthttp.RequestHeader) (uint16, bool) {
if headers == nil {
return 0, false
}
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
// HTTP status code returned by the request // HTTP status code returned by the request
if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 { if value := headers.Peek("X-Code"); len(value) > 0 && len(value) <= 3 {
if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 { if code, err := strconv.ParseUint(string(value), 10, 16); err == nil && code > 0 && code < 999 {
return uint16(code), true return uint16(code), true
} }
} }
@ -50,7 +55,7 @@ func extractCodeFromHeaders(headers http.Header) (uint16, bool) {
} }
// HeadersContainCode checks if the given headers contain an error code. // HeadersContainCode checks if the given headers contain an error code.
func HeadersContainCode(headers http.Header) (ok bool) { func HeadersContainCode(headers *fasthttp.RequestHeader) (ok bool) {
_, ok = extractCodeFromHeaders(headers) _, ok = extractCodeFromHeaders(headers)
return return

View File

@ -1,10 +1,10 @@
package error_page_test package error_page_test
import ( import (
"net/http"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page" "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
) )
@ -36,18 +36,26 @@ func TestURLContainsCode(t *testing.T) {
func TestHeadersContainCode(t *testing.T) { func TestHeadersContainCode(t *testing.T) {
t.Parallel() t.Parallel()
var mkHeaders = func(key, value string) *fasthttp.RequestHeader {
var out = new(fasthttp.RequestHeader)
out.Set(key, value)
return out
}
for name, _tt := range map[string]struct { for name, _tt := range map[string]struct {
giveHeaders http.Header giveHeaders *fasthttp.RequestHeader
wantOk bool wantOk bool
}{ }{
"with code": {giveHeaders: http.Header{"X-Code": {"404"}}, wantOk: true}, "with code": {giveHeaders: mkHeaders("X-Code", "404"), wantOk: true},
"empty": {giveHeaders: nil}, "empty": {giveHeaders: nil},
"no code": {giveHeaders: http.Header{"X-Code": {""}}}, "no code": {giveHeaders: mkHeaders("X-Code", "")},
"wrong": {giveHeaders: http.Header{"X-Code": {"foo"}}}, "wrong": {giveHeaders: mkHeaders("X-Code", "foo")},
"too big": {giveHeaders: http.Header{"X-Code": {"1000"}}}, "too big": {giveHeaders: mkHeaders("X-Code", "1000")},
"too small": {giveHeaders: http.Header{"X-Code": {"0"}}}, "too small": {giveHeaders: mkHeaders("X-Code", "0")},
"negative": {giveHeaders: http.Header{"X-Code": {"-1"}}}, "negative": {giveHeaders: mkHeaders("X-Code", "-1")},
} { } {
tt := _tt tt := _tt

View File

@ -2,10 +2,11 @@ package error_page
import ( import (
"math" "math"
"net/http"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"github.com/valyala/fasthttp"
) )
type preferredFormat = byte type preferredFormat = byte
@ -21,10 +22,10 @@ const (
// detectPreferredFormatForClient detects the preferred format for the client based on the headers. // detectPreferredFormatForClient detects the preferred format for the client based on the headers.
// It supports the following headers: Content-Type, Accept, X-Format. // It supports the following headers: Content-Type, Accept, X-Format.
// If the headers are not set or the format is not recognized, it returns unknownFormat. // If the headers are not set or the format is not recognized, it returns unknownFormat.
func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nolint:funlen,gocognit func detectPreferredFormatForClient(headers *fasthttp.RequestHeader) preferredFormat { //nolint:funlen,gocognit
var contentType, accept string var contentType, accept string
if contentTypeHeader := strings.TrimSpace(headers.Get("Content-Type")); contentTypeHeader != "" { //nolint:nestif if contentTypeHeader := strings.TrimSpace(string(headers.Peek("Content-Type"))); contentTypeHeader != "" { //nolint:nestif,lll
// https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type // https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type
// text/html; charset=utf-8 // text/html; charset=utf-8
// multipart/form-data; boundary=something // multipart/form-data; boundary=something
@ -38,11 +39,11 @@ func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nol
// take the whole value // take the whole value
contentType = contentTypeHeader contentType = contentTypeHeader
} }
} else if xFormatHeader := strings.TrimSpace(headers.Get("X-Format")); xFormatHeader != "" { } else if xFormatHeader := strings.TrimSpace(string(headers.Peek("X-Format"))); xFormatHeader != "" {
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
// Value of the `Accept` header sent by the client // Value of the `Accept` header sent by the client
accept = xFormatHeader accept = xFormatHeader
} else if acceptHeader := strings.TrimSpace(headers.Get("Accept")); acceptHeader != "" { } else if acceptHeader := strings.TrimSpace(string(headers.Peek("Accept"))); acceptHeader != "" {
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
// text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8 // text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8
// text/html // text/html

View File

@ -1,80 +1,80 @@
package error_page package error_page
import ( import (
"net/http"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
) )
func Test_detectPreferredFormatForClient(t *testing.T) { func Test_detectPreferredFormatForClient(t *testing.T) {
t.Parallel() t.Parallel()
for name, _tt := range map[string]struct { for name, _tt := range map[string]struct {
giveHeaders http.Header giveHeaders map[string][]string
wantFormat preferredFormat wantFormat preferredFormat
}{ }{
"content type json": { "content type json": {
giveHeaders: http.Header{"Content-Type": {"application/jSoN"}}, giveHeaders: map[string][]string{"Content-Type": {"application/jSoN"}},
wantFormat: jsonFormat, wantFormat: jsonFormat,
}, },
"content type xml": { "content type xml": {
giveHeaders: http.Header{"Content-Type": {"application/xml; charset=UTF-8"}}, giveHeaders: map[string][]string{"Content-Type": {"application/xml; charset=UTF-8"}},
wantFormat: xmlFormat, wantFormat: xmlFormat,
}, },
"content type html": { "content type html": {
giveHeaders: http.Header{"Content-Type": {"text/hTmL; charset=utf-8"}}, giveHeaders: map[string][]string{"Content-Type": {"text/hTmL; charset=utf-8"}},
wantFormat: htmlFormat, wantFormat: htmlFormat,
}, },
"content type plain": { "content type plain": {
giveHeaders: http.Header{"Content-Type": {"text/plaIN"}}, giveHeaders: map[string][]string{"Content-Type": {"text/plaIN"}},
wantFormat: plainTextFormat, wantFormat: plainTextFormat,
}, },
"accept json": { "accept json": {
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat, wantFormat: jsonFormat,
}, },
"accept xml, depends on weight": { "accept xml, depends on weight": {
giveHeaders: http.Header{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}},
wantFormat: xmlFormat, wantFormat: xmlFormat,
}, },
"accept json, depends on weight": { "accept json, depends on weight": {
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat, wantFormat: jsonFormat,
}, },
"accept xml": { "accept xml": {
giveHeaders: http.Header{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
wantFormat: xmlFormat, wantFormat: xmlFormat,
}, },
"accept html": { "accept html": {
giveHeaders: http.Header{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}},
wantFormat: htmlFormat, wantFormat: htmlFormat,
}, },
"accept plain": { "accept plain": {
giveHeaders: http.Header{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}},
wantFormat: plainTextFormat, wantFormat: plainTextFormat,
}, },
"accept json, weighted values only": { "accept json, weighted values only": {
giveHeaders: http.Header{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}}, giveHeaders: map[string][]string{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}},
wantFormat: jsonFormat, wantFormat: jsonFormat,
}, },
"x-format json, depends on weight": { "x-format json, depends on weight": {
giveHeaders: http.Header{"X-Format": {"application/jsoN,*/*;q=0.8"}}, giveHeaders: map[string][]string{"X-Format": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat, wantFormat: jsonFormat,
}, },
"x-format xml": { "x-format xml": {
giveHeaders: http.Header{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, giveHeaders: map[string][]string{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
wantFormat: xmlFormat, wantFormat: xmlFormat,
}, },
"content type has priority over accept": { "content type has priority over accept": {
giveHeaders: http.Header{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}}, giveHeaders: map[string][]string{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}},
wantFormat: plainTextFormat, wantFormat: plainTextFormat,
}, },
"accept has priority over x-format": { "accept has priority over x-format": {
giveHeaders: http.Header{"Accept": {"application/xml"}, "X-Format": {"text/plain"}}, giveHeaders: map[string][]string{"Accept": {"application/xml"}, "X-Format": {"text/plain"}},
wantFormat: plainTextFormat, wantFormat: plainTextFormat,
}, },
@ -82,25 +82,33 @@ func Test_detectPreferredFormatForClient(t *testing.T) {
giveHeaders: nil, giveHeaders: nil,
}, },
"empty content type": { "empty content type": {
giveHeaders: http.Header{"Content-Type": {" "}}, giveHeaders: map[string][]string{"Content-Type": {" "}},
}, },
"wrong content type": { "wrong content type": {
giveHeaders: http.Header{"Content-Type": {"multipart/form-data; boundary=something"}}, giveHeaders: map[string][]string{"Content-Type": {"multipart/form-data; boundary=something"}},
}, },
"wrong accept": { "wrong accept": {
giveHeaders: http.Header{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}}, giveHeaders: map[string][]string{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}},
}, },
"none on invalid input": { "none on invalid input": {
giveHeaders: http.Header{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}}, giveHeaders: map[string][]string{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}},
}, },
"completely unknown": { "completely unknown": {
giveHeaders: http.Header{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}}, giveHeaders: map[string][]string{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}},
}, },
} { } {
tt := _tt tt := _tt
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(tt.giveHeaders)) var headers = new(fasthttp.RequestHeader)
for key, values := range tt.giveHeaders {
for _, value := range values {
headers.Add(key, value)
}
}
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(headers))
}) })
} }
} }

View File

@ -7,21 +7,24 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/config" "gh.tarampamp.am/error-pages/internal/config"
"gh.tarampamp.am/error-pages/internal/logger" "gh.tarampamp.am/error-pages/internal/logger"
"gh.tarampamp.am/error-pages/internal/template" "gh.tarampamp.am/error-pages/internal/template"
) )
// New creates a new handler that returns an error page with the specified status code and format. // New creates a new handler that returns an error page with the specified status code and format.
func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,gocognit,gocyclo func New(cfg *config.Config, log *logger.Logger) fasthttp.RequestHandler { //nolint:funlen,gocognit,gocyclo
const contentTypeHeader = "Content-Type" return func(ctx *fasthttp.RequestCtx) {
var (
reqHeaders = &ctx.Request.Header
code uint16
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if fromUrl, okUrl := extractCodeFromURL(string(ctx.RequestURI())); okUrl {
var code uint16
if fromUrl, okUrl := extractCodeFromURL(r.URL.Path); okUrl {
code = fromUrl code = fromUrl
} else if fromHeader, okHeaders := extractCodeFromHeaders(r.Header); okHeaders { } else if fromHeader, okHeaders := extractCodeFromHeaders(reqHeaders); okHeaders {
code = fromHeader code = fromHeader
} else { } else {
code = cfg.DefaultCodeToRender code = cfg.DefaultCodeToRender
@ -35,23 +38,23 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
httpCode = http.StatusOK httpCode = http.StatusOK
} }
var format = detectPreferredFormatForClient(r.Header) var format = detectPreferredFormatForClient(reqHeaders)
{ // deal with the headers { // deal with the headers
switch format { switch format {
case jsonFormat: case jsonFormat:
w.Header().Set(contentTypeHeader, "application/json; charset=utf-8") ctx.SetContentType("application/json; charset=utf-8")
case xmlFormat: case xmlFormat:
w.Header().Set(contentTypeHeader, "application/xml; charset=utf-8") ctx.SetContentType("application/xml; charset=utf-8")
case htmlFormat: case htmlFormat:
w.Header().Set(contentTypeHeader, "text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
default: default:
w.Header().Set(contentTypeHeader, "text/plain; charset=utf-8") // plainTextFormat as default ctx.SetContentType("text/plain; charset=utf-8") // plainTextFormat as default
} }
// https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag // https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag
// disallow indexing of the error pages // disallow indexing of the error pages
w.Header().Set("X-Robots-Tag", "noindex") ctx.Response.Header.Set("X-Robots-Tag", "noindex")
switch code { switch code {
case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests, case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests,
@ -59,18 +62,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
http.StatusGatewayTimeout: http.StatusGatewayTimeout:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
// tell the client (search crawler) to retry the request after 120 seconds // tell the client (search crawler) to retry the request after 120 seconds
w.Header().Set("Retry-After", "120") ctx.Response.Header.Set("Retry-After", "120")
} }
// proxy the headers from the incoming request to the error page response if they are defined in the config // proxy the headers from the incoming request to the error page response if they are defined in the config
for _, proxyHeader := range cfg.ProxyHeaders { for _, proxyHeader := range cfg.ProxyHeaders {
if value := r.Header.Get(proxyHeader); value != "" { if value := reqHeaders.Peek(proxyHeader); len(value) > 0 {
w.Header().Set(proxyHeader, value) ctx.Response.Header.SetBytesV(proxyHeader, value)
} }
} }
} }
w.WriteHeader(httpCode) ctx.SetStatusCode(httpCode)
// prepare the template properties for rendering // prepare the template properties for rendering
var tplProps = template.Props{ var tplProps = template.Props{
@ -81,14 +84,14 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
//nolint:lll //nolint:lll
if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
tplProps.OriginalURI = r.Header.Get("X-Original-URI") // (ingress-nginx) URI that caused the error tplProps.OriginalURI = string(reqHeaders.Peek("X-Original-URI")) // (ingress-nginx) URI that caused the error
tplProps.Namespace = r.Header.Get("X-Namespace") // (ingress-nginx) namespace where the backend Service is located tplProps.Namespace = string(reqHeaders.Peek("X-Namespace")) // (ingress-nginx) namespace where the backend Service is located
tplProps.IngressName = r.Header.Get("X-Ingress-Name") // (ingress-nginx) name of the Ingress where the backend is defined tplProps.IngressName = string(reqHeaders.Peek("X-Ingress-Name")) // (ingress-nginx) name of the Ingress where the backend is defined
tplProps.ServiceName = r.Header.Get("X-Service-Name") // (ingress-nginx) name of the Service backing the backend tplProps.ServiceName = string(reqHeaders.Peek("X-Service-Name")) // (ingress-nginx) name of the Service backing the backend
tplProps.ServicePort = r.Header.Get("X-Service-Port") // (ingress-nginx) port number of the Service backing the backend tplProps.ServicePort = string(reqHeaders.Peek("X-Service-Port")) // (ingress-nginx) port number of the Service backing the backend
tplProps.RequestID = r.Header.Get("X-Request-Id") // (ingress-nginx) unique ID that identifies the request - same as for backend service tplProps.RequestID = string(reqHeaders.Peek("X-Request-Id")) // (ingress-nginx) unique ID that identifies the request - same as for backend service
tplProps.ForwardedFor = r.Header.Get("X-Forwarded-For") // the value of the `X-Forwarded-For` header tplProps.ForwardedFor = string(reqHeaders.Peek("X-Forwarded-For")) // the value of the `X-Forwarded-For` header
tplProps.Host = r.Host // the value of the `Host` header tplProps.Host = string(reqHeaders.Peek("Host")) // the value of the `Host` header
} }
// try to find the code message and description in the config and if not - use the standard status text or fallback // try to find the code message and description in the config and if not - use the standard status text or fallback
@ -105,18 +108,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
case format == jsonFormat && cfg.Formats.JSON != "": case format == jsonFormat && cfg.Formats.JSON != "":
if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil { if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil {
j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error())) j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error()))
write(w, log, j) write(ctx, log, j)
} else { } else {
write(w, log, content) write(ctx, log, content)
} }
case format == xmlFormat && cfg.Formats.XML != "": case format == xmlFormat && cfg.Formats.XML != "":
if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil { if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil {
write(w, log, fmt.Sprintf( write(ctx, log, fmt.Sprintf(
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<error>Failed to render the XML template: %s</error>", err.Error(), "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<error>Failed to render the XML template: %s</error>", err.Error(),
)) ))
} else { } else {
write(w, log, content) write(ctx, log, content)
} }
case format == htmlFormat: case format == htmlFormat:
@ -125,16 +128,16 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
if tpl, found := cfg.Templates.Get(templateName); found { if tpl, found := cfg.Templates.Get(templateName); found {
if content, err := template.Render(tpl, tplProps); err != nil { if content, err := template.Render(tpl, tplProps); err != nil {
// TODO: add GZIP compression for the HTML content support // TODO: add GZIP compression for the HTML content support
write(w, log, fmt.Sprintf( write(ctx, log, fmt.Sprintf(
"<!DOCTYPE html>\n<html><body>Failed to render the HTML template %s: %s</body></html>", "<!DOCTYPE html>\n<html><body>Failed to render the HTML template %s: %s</body></html>",
templateName, templateName,
err.Error(), err.Error(),
)) ))
} else { } else {
write(w, log, content) write(ctx, log, content)
} }
} else { } else {
write(w, log, fmt.Sprintf( write(ctx, log, fmt.Sprintf(
"<!DOCTYPE html>\n<html><body>Template %s not found and cannot be used</body></html>", templateName, "<!DOCTYPE html>\n<html><body>Template %s not found and cannot be used</body></html>", templateName,
)) ))
} }
@ -142,18 +145,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
default: // plainTextFormat as default default: // plainTextFormat as default
if cfg.Formats.PlainText != "" { if cfg.Formats.PlainText != "" {
if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil { if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil {
write(w, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error())) write(ctx, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error()))
} else { } else {
write(w, log, content) write(ctx, log, content)
} }
} else { } else {
write(w, log, `The requested content format is not supported. write(ctx, log, `The requested content format is not supported.
Please create an issue on the project's GitHub page to request support for this format. Please create an issue on the project's GitHub page to request support for this format.
Supported formats: JSON, XML, HTML, Plain Text`) Supported formats: JSON, XML, HTML, Plain Text`)
} }
} }
}) }
} }
var ( var (
@ -204,7 +207,7 @@ func templateToUse(cfg *config.Config) string {
} }
// write the content to the response writer and log the error if any. // write the content to the response writer and log the error if any.
func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content T) { func write[T string | []byte](ctx *fasthttp.RequestCtx, log *logger.Logger, content T) {
var data []byte var data []byte
if s, ok := any(content).(string); ok { if s, ok := any(content).(string); ok {
@ -213,7 +216,7 @@ func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content
data = any(content).([]byte) data = any(content).([]byte)
} }
if _, err := w.Write(data); err != nil && log != nil { if _, err := ctx.Write(data); err != nil && log != nil {
log.Error("failed to write the response body", log.Error("failed to write the response body",
logger.String("content", string(data)), logger.String("content", string(data)),
logger.Error(err), logger.Error(err),

View File

@ -2,13 +2,14 @@ package error_page_test
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gh.tarampamp.am/error-pages/internal/config" "gh.tarampamp.am/error-pages/internal/config"
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page" "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
"gh.tarampamp.am/error-pages/internal/http/httptest"
"gh.tarampamp.am/error-pages/internal/logger" "gh.tarampamp.am/error-pages/internal/logger"
) )
@ -26,7 +27,7 @@ func TestHandler(t *testing.T) {
}{ }{
"common, plain text": { "common, plain text": {
giveConfig: func() *config.Config { cfg := config.New(); return &cfg }, giveConfig: func() *config.Config { cfg := config.New(); return &cfg },
giveUrl: "/", giveUrl: "http://testing/",
giveHeaders: map[string]string{"Content-Type": "text/plain"}, giveHeaders: map[string]string{"Content-Type": "text/plain"},
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -41,7 +42,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/", giveUrl: "http://testing/",
giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"}, giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"},
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -60,7 +61,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/503.html", giveUrl: "http://testing/503.html",
giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"}, giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"},
wantStatusCode: http.StatusServiceUnavailable, wantStatusCode: http.StatusServiceUnavailable,
@ -78,7 +79,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/500", giveUrl: "http://testing/500",
giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"}, giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"},
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -96,7 +97,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/503", giveUrl: "http://example.com/503",
giveHeaders: map[string]string{ giveHeaders: map[string]string{
"Accept": "application/json", "Accept": "application/json",
"X-Original-URI": "/foo/bar", "X-Original-URI": "/foo/bar",
@ -106,7 +107,6 @@ func TestHandler(t *testing.T) {
"X-Service-Port": "666", "X-Service-Port": "666",
"X-Request-ID": "req-id-777", "X-Request-ID": "req-id-777",
"X-Forwarded-For": "123.123.123.123:12312", "X-Forwarded-For": "123.123.123.123:12312",
"Host": "example.com",
}, },
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -133,7 +133,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/100", giveUrl: "http://testing/100",
giveHeaders: map[string]string{"Accept": "application/json"}, giveHeaders: map[string]string{"Accept": "application/json"},
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -148,7 +148,7 @@ func TestHandler(t *testing.T) {
return &cfg return &cfg
}, },
giveUrl: "/1", giveUrl: "http://testing/1",
giveHeaders: map[string]string{"Accept": "application/json"}, giveHeaders: map[string]string{"Accept": "application/json"},
wantStatusCode: http.StatusOK, wantStatusCode: http.StatusOK,
@ -159,31 +159,30 @@ func TestHandler(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
var ( var handler = error_page.New(tt.giveConfig(), logger.NewNop())
req = httptest.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
handler = error_page.New(tt.giveConfig(), logger.NewNop()) req, reqErr := http.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
rr = httptest.NewRecorder() require.NoError(t, reqErr)
)
for k, v := range tt.giveHeaders { for k, v := range tt.giveHeaders {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
handler.ServeHTTP(rr, req) httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
assert.Equal(t, tt.wantStatusCode, status)
assert.Equal(t, rr.Code, tt.wantStatusCode) for hName, hWant := range tt.wantHeaders {
for hGot := range headers {
for hName, hWant := range tt.wantHeaders { if hGot == hName {
for hGot := range rr.Header() { assert.Contains(t, hWant, headers.Get(hGot))
if hGot == hName { }
assert.Contains(t, hWant, rr.Header().Get(hGot))
} }
} }
}
for _, wantBodyInclude := range tt.wantBodyIncludes { for _, wantBodyInclude := range tt.wantBodyIncludes {
assert.Contains(t, rr.Body.String(), wantBodyInclude) assert.Contains(t, body, wantBodyInclude)
} }
})
}) })
} }
} }
@ -207,19 +206,17 @@ func TestRotationModeOnEachRequest(t *testing.T) {
) )
for range 300 { for range 300 {
var ( req, reqErr := http.NewRequest(http.MethodGet, "http://testing/", http.NoBody)
req = httptest.NewRequest(http.MethodGet, "/", http.NoBody) require.NoError(t, reqErr)
rr = httptest.NewRecorder()
)
req.Header.Set("Accept", "text/html") req.Header.Set("Accept", "text/html")
handler.ServeHTTP(rr, req) httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
if lastResponseBody != body {
if lastResponseBody != rr.Body.String() { changedTimes++
changedTimes++ lastResponseBody = body
lastResponseBody = rr.Body.String() }
} })
} }
assert.True(t, changedTimes > 30, "the template should be changed at least 30 times") assert.True(t, changedTimes > 30, "the template should be changed at least 30 times")

View File

@ -2,24 +2,29 @@ package live
import ( import (
"net/http" "net/http"
"github.com/valyala/fasthttp"
) )
// New creates a new handler that returns "OK" for GET and HEAD requests. // New creates a new handler that returns "OK" for GET and HEAD requests.
func New() http.Handler { func New() fasthttp.RequestHandler {
var body = []byte("OK\n") var (
body = []byte("OK\n")
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(ctx *fasthttp.RequestCtx) {
switch r.Method { switch string(ctx.Method()) {
case http.MethodGet: case fasthttp.MethodGet:
w.Header().Set("Content-Type", "text/plain; charset=utf-8") ctx.SetContentType("text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK) ctx.SetStatusCode(http.StatusOK)
_, _ = w.Write(body) _, _ = ctx.Write(body)
case http.MethodHead: case fasthttp.MethodHead:
w.WriteHeader(http.StatusOK) ctx.SetStatusCode(http.StatusOK)
default: default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) ctx.Error(notAllowed, http.StatusMethodNotAllowed)
} }
}) }
} }

View File

@ -2,43 +2,37 @@ package live_test
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/live" "gh.tarampamp.am/error-pages/internal/http/handlers/live"
"gh.tarampamp.am/error-pages/internal/http/httptest"
) )
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
t.Parallel() t.Parallel()
var handler = live.New() var (
handler = live.New()
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) { t.Run("get", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
) assert.Equal(t, "OK\n", body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, "OK\n", rr.Body.String())
}) })
t.Run("head", func(t *testing.T) { t.Run("head", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Empty(t, headers.Get("Content-Type"))
) assert.Empty(t, body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
}) })
t.Run("method not allowed", func(t *testing.T) { t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost, http.MethodPost,
http.MethodPut, http.MethodPut,
} { } {
var ( httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(method, "http://testing", http.NoBody) assert.Equal(t, http.StatusMethodNotAllowed, status)
rr = httptest.NewRecorder() assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
) assert.Equal(t, "Method Not Allowed\n", body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
} }
}) })
} }

View File

@ -3,25 +3,29 @@ package static
import ( import (
_ "embed" _ "embed"
"net/http" "net/http"
"github.com/valyala/fasthttp"
) )
//go:embed favicon.ico //go:embed favicon.ico
var Favicon []byte var Favicon []byte
// New creates a new handler that returns the provided content for GET and HEAD requests. // New creates a new handler that returns the provided content for GET and HEAD requests.
func New(content []byte) http.Handler { func New(content []byte) fasthttp.RequestHandler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", http.DetectContentType(content))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(content)
case http.MethodHead: return func(ctx *fasthttp.RequestCtx) {
w.WriteHeader(http.StatusOK) switch string(ctx.Method()) {
case fasthttp.MethodGet:
ctx.SetContentType(http.DetectContentType(content))
ctx.SetStatusCode(http.StatusOK)
_, _ = ctx.Write(content)
case fasthttp.MethodHead:
ctx.SetStatusCode(http.StatusOK)
default: default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) ctx.Error(notAllowed, http.StatusMethodNotAllowed)
} }
}) }
} }

View File

@ -2,43 +2,37 @@ package static_test
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/static" "gh.tarampamp.am/error-pages/internal/http/handlers/static"
"gh.tarampamp.am/error-pages/internal/http/httptest"
) )
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
t.Parallel() t.Parallel()
var handler = static.New([]byte{1, 2, 3}) var (
handler = static.New([]byte{1, 2, 3})
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) { t.Run("get", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Equal(t, "application/octet-stream", headers.Get("Content-Type"))
) assert.Equal(t, []byte{1, 2, 3}, []byte(body))
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "application/octet-stream")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.Bytes(), []byte{1, 2, 3})
}) })
t.Run("head", func(t *testing.T) { t.Run("head", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Empty(t, headers.Get("Content-Type"))
) assert.Empty(t, body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
}) })
t.Run("method not allowed", func(t *testing.T) { t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost, http.MethodPost,
http.MethodPut, http.MethodPut,
} { } {
var ( httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(method, "http://testing", http.NoBody) assert.Equal(t, http.StatusMethodNotAllowed, status)
rr = httptest.NewRecorder() assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
) assert.Equal(t, "Method Not Allowed\n", body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
} }
}) })
} }
@ -65,16 +54,15 @@ func TestServeHTTP(t *testing.T) {
func TestServeHTTP_Favicon(t *testing.T) { func TestServeHTTP_Favicon(t *testing.T) {
t.Parallel() t.Parallel()
var ( httptest.HandleFast(t,
handler = static.New(static.Favicon) static.New(static.Favicon),
http.MethodGet,
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) "http://testing",
rr = httptest.NewRecorder() http.NoBody,
func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, "image/x-icon", headers.Get("Content-Type"))
assert.Equal(t, static.Favicon, []byte(body))
},
) )
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "image/x-icon")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.Bytes(), static.Favicon)
} }

View File

@ -4,28 +4,32 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
"github.com/valyala/fasthttp"
) )
// New creates a handler that returns the version of the service in JSON format. // New creates a handler that returns the version of the service in JSON format.
func New(ver string) http.Handler { func New(ver string) fasthttp.RequestHandler {
var body, _ = json.Marshal(struct { //nolint:errchkjson var body, _ = json.Marshal(struct { //nolint:errchkjson
Version string `json:"version"` Version string `json:"version"`
}{ }{
Version: strings.TrimSpace(ver), Version: strings.TrimSpace(ver),
}) })
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
case http.MethodHead: return func(ctx *fasthttp.RequestCtx) {
w.WriteHeader(http.StatusOK) switch string(ctx.Method()) {
case fasthttp.MethodGet:
ctx.SetContentType("application/json; charset=utf-8")
ctx.SetStatusCode(http.StatusOK)
_, _ = ctx.Write(body)
case fasthttp.MethodHead:
ctx.SetStatusCode(http.StatusOK)
default: default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) ctx.Error(notAllowed, http.StatusMethodNotAllowed)
} }
}) }
} }

View File

@ -2,43 +2,37 @@ package version_test
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/version" "gh.tarampamp.am/error-pages/internal/http/handlers/version"
"gh.tarampamp.am/error-pages/internal/http/httptest"
) )
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
t.Parallel() t.Parallel()
var handler = version.New("\t\n foo@bar ") var (
handler = version.New("\t\n foo@bar ")
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) { t.Run("get", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Equal(t, "application/json; charset=utf-8", headers.Get("Content-Type"))
) assert.Equal(t, `{"version":"foo@bar"}`, body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "application/json; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.String(), `{"version":"foo@bar"}`)
}) })
t.Run("head", func(t *testing.T) { t.Run("head", func(t *testing.T) {
var ( httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody) assert.Equal(t, http.StatusOK, status)
rr = httptest.NewRecorder() assert.Empty(t, headers.Get("Content-Type"))
) assert.Empty(t, body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
}) })
t.Run("method not allowed", func(t *testing.T) { t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost, http.MethodPost,
http.MethodPut, http.MethodPut,
} { } {
var ( httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
req = httptest.NewRequest(method, "http://testing", http.NoBody) assert.Equal(t, http.StatusMethodNotAllowed, status)
rr = httptest.NewRecorder() assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
) assert.Equal(t, "Method Not Allowed\n", body)
})
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
} }
}) })
} }

View File

@ -0,0 +1,69 @@
// Package httptest provides utilities for (fast-)HTTP testing.
package httptest
import (
"context"
"io"
"net"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
// HandleFastRequest serves http request using provided fasthttp handler and HTTP request.
func HandleFastRequest(
t *testing.T,
handler fasthttp.RequestHandler,
req *http.Request,
check func(status int, body string, _ http.Header),
) {
t.Helper()
// create in-memory listener
var ln = fasthttputil.NewInmemoryListener()
defer func() { require.NoError(t, ln.Close()) }()
// start fasthttp server
go func() { require.NoError(t, fasthttp.Serve(ln, handler)) }()
// send http request
resp, respErr := (&http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return ln.Dial() },
},
}).Do(req)
require.NoError(t, respErr)
// close response body after the test
defer func() { assert.NoError(t, resp.Body.Close()) }()
// read response body
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// check the response
check(resp.StatusCode, string(respBody), resp.Header)
}
// HandleFast serves http request using provided fasthttp handler.
func HandleFast(
t *testing.T,
handler fasthttp.RequestHandler,
method string,
url string,
body io.Reader,
check func(status int, body string, _ http.Header),
) {
t.Helper()
// create http request
req, reqErr := http.NewRequest(method, url, body)
require.NoError(t, reqErr)
// serve http request
HandleFastRequest(t, handler, req, check)
}

View File

@ -1,20 +1,24 @@
package logreq package logreq
import ( import (
"net/http"
"time" "time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/logger" "gh.tarampamp.am/error-pages/internal/logger"
) )
// New creates a middleware for [http.ServeMux] that logs every incoming request. // New creates a middleware that logs every incoming request.
// //
// The skipper function should return true if the request should be skipped. It's ok to pass nil. // The skipper function should return true if the request should be skipped. It's ok to pass nil.
func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler) http.Handler { func New(
return func(next http.Handler) http.Handler { log *logger.Logger,
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { skipper func(*fasthttp.RequestCtx) bool,
if skipper != nil && skipper(r) { ) func(fasthttp.RequestHandler) fasthttp.RequestHandler {
next.ServeHTTP(w, r) return func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
if skipper != nil && skipper(ctx) {
next(ctx)
return return
} }
@ -23,27 +27,35 @@ func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler
defer func() { defer func() {
var fields = []logger.Attr{ var fields = []logger.Attr{
logger.String("useragent", r.UserAgent()), logger.Int("status code", ctx.Response.StatusCode()),
logger.String("method", r.Method), logger.String("useragent", string(ctx.UserAgent())),
logger.String("url", r.URL.String()), logger.String("method", string(ctx.Method())),
logger.String("referer", r.Referer()), logger.String("url", string(ctx.RequestURI())),
logger.String("content type", w.Header().Get("Content-Type")), logger.String("referer", string(ctx.Referer())),
logger.String("remote addr", r.RemoteAddr), logger.String("content type", string(ctx.Response.Header.ContentType())),
logger.String("method", r.Method), logger.String("remote addr", ctx.RemoteAddr().String()),
logger.Duration("duration", time.Since(now).Round(time.Microsecond)), logger.Duration("duration", time.Since(now).Round(time.Microsecond)),
} }
if log.Level() <= logger.DebugLevel { if log.Level() <= logger.DebugLevel {
var (
reqHeaders = make(map[string]string)
respHeaders = make(map[string]string)
)
ctx.Request.Header.VisitAll(func(key, value []byte) { reqHeaders[string(key)] = string(value) })
ctx.Response.Header.VisitAll(func(key, value []byte) { respHeaders[string(key)] = string(value) })
fields = append(fields, fields = append(fields,
logger.Any("request headers", r.Header.Clone()), logger.Any("request headers", reqHeaders),
logger.Any("response headers", w.Header().Clone()), logger.Any("response headers", respHeaders),
) )
} }
log.Info("HTTP request processed", fields...) log.Info("HTTP request processed", fields...)
}() }()
next.ServeHTTP(w, r) next(ctx)
}) }
} }
} }

View File

@ -3,11 +3,12 @@ package logreq_test
import ( import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/http/httptest"
"gh.tarampamp.am/error-pages/internal/http/middleware/logreq" "gh.tarampamp.am/error-pages/internal/http/middleware/logreq"
"gh.tarampamp.am/error-pages/internal/logger" "gh.tarampamp.am/error-pages/internal/logger"
) )
@ -19,18 +20,19 @@ func TestNew(t *testing.T) {
buf bytes.Buffer buf bytes.Buffer
log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf) log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf)
mw = logreq.New(log, nil) mw = logreq.New(log, nil)
rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPut, "http://testing/foo/bar", http.NoBody)
req = httptest.NewRequest(http.MethodPut, "/foo/bar", http.NoBody)
) )
req.Header.Set("User-Agent", "test") req.Header.Set("User-Agent", "test")
req.Header.Set("Referer", "https://example.com") req.Header.Set("Referer", "https://example.com")
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httptest.HandleFastRequest(t,
w.WriteHeader(http.StatusOK) mw(func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(http.StatusOK) }),
})).ServeHTTP(rr, req) req,
func(status int, body string, _ http.Header) { assert.Equal(t, http.StatusOK, status) },
)
var logRecord = buf.String() var logRecord = buf.String()

View File

@ -2,12 +2,15 @@ package http
import ( import (
"context" "context"
"errors"
"fmt"
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/appmeta" "gh.tarampamp.am/error-pages/internal/appmeta"
"gh.tarampamp.am/error-pages/internal/config" "gh.tarampamp.am/error-pages/internal/config"
ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page" ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
@ -21,11 +24,11 @@ import (
// Server is an HTTP server for serving error pages. // Server is an HTTP server for serving error pages.
type Server struct { type Server struct {
log *logger.Logger log *logger.Logger
server *http.Server server *fasthttp.Server
} }
// NewServer creates a new HTTP server. // NewServer creates a new HTTP server.
func NewServer(baseCtx context.Context, log *logger.Logger) Server { func NewServer(log *logger.Logger) Server {
const ( const (
readTimeout = 30 * time.Second readTimeout = 30 * time.Second
writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout
@ -34,13 +37,14 @@ func NewServer(baseCtx context.Context, log *logger.Logger) Server {
return Server{ return Server{
log: log, log: log,
server: &http.Server{ server: &fasthttp.Server{
ReadTimeout: readTimeout, ReadTimeout: readTimeout,
WriteTimeout: writeTimeout, WriteTimeout: writeTimeout,
ReadHeaderTimeout: readTimeout, ReadBufferSize: maxHeaderBytes,
MaxHeaderBytes: maxHeaderBytes, DisablePreParseMultipartForm: true,
ErrorLog: logger.NewStdLog(log), NoDefaultServerHeader: true,
BaseContext: func(net.Listener) context.Context { return baseCtx }, CloseOnShutdown: true,
Logger: logger.NewStdLog(log),
}, },
} }
} }
@ -52,60 +56,78 @@ func (s *Server) Register(cfg *config.Config) error {
versionHandler = version.New(appmeta.Version()) versionHandler = version.New(appmeta.Version())
errorPagesHandler = ep.New(cfg, s.log) errorPagesHandler = ep.New(cfg, s.log)
faviconHandler = static.New(static.Favicon) faviconHandler = static.New(static.Favicon)
notFound = http.StatusText(http.StatusNotFound) + "\n"
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
) )
s.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.server.Handler = func(ctx *fasthttp.RequestCtx) {
var url, method = r.URL.Path, r.Method var url, method = string(ctx.RequestURI()), string(ctx.Method())
switch { switch {
// live endpoints // live endpoints
case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live": case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live":
liveHandler.ServeHTTP(w, r) liveHandler(ctx)
// version endpoint // version endpoint
case url == "/version": case url == "/version":
versionHandler.ServeHTTP(w, r) versionHandler(ctx)
// favicon.ico endpoint // favicon.ico endpoint
case url == "/favicon.ico": case url == "/favicon.ico":
faviconHandler.ServeHTTP(w, r) faviconHandler(ctx)
// error pages endpoints: // error pages endpoints:
// - / // - /
// - /{code}.html // - /{code}.html
// - /{code}.htm // - /{code}.htm
// - /{code} // - /{code}
case method == http.MethodGet && (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(r.Header)): case method == fasthttp.MethodGet &&
errorPagesHandler.ServeHTTP(w, r) (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(&ctx.Request.Header)):
errorPagesHandler(ctx)
// wrong requests handling // wrong requests handling
default: default:
switch { switch {
case method == http.MethodHead: case method == fasthttp.MethodHead:
w.WriteHeader(http.StatusNotFound) ctx.Error(notAllowed, fasthttp.StatusNotFound)
case method == http.MethodGet: case method == fasthttp.MethodGet:
http.NotFound(w, r) ctx.Error(notFound, fasthttp.StatusNotFound)
default: default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) ctx.Error(notAllowed, fasthttp.StatusMethodNotAllowed)
} }
} }
}) }
// apply middleware // apply middleware
s.server.Handler = logreq.New(s.log, func(r *http.Request) bool { s.server.Handler = logreq.New(s.log, func(ctx *fasthttp.RequestCtx) bool {
// skip logging healthcheck and .ico (favicon) requests // skip logging healthcheck and .ico (favicon) requests
return strings.Contains(strings.ToLower(r.UserAgent()), "healthcheck") || return strings.Contains(strings.ToLower(string(ctx.UserAgent())), "healthcheck") ||
strings.HasSuffix(r.URL.Path, ".ico") strings.HasSuffix(string(ctx.Path()), ".ico")
})(s.server.Handler) })(s.server.Handler)
return nil return nil
} }
// Start server. // Start server.
func (s *Server) Start(ip string, port uint16) error { func (s *Server) Start(ip string, port uint16) (err error) {
s.server.Addr = ip + ":" + strconv.Itoa(int(port)) if net.ParseIP(ip) == nil {
return errors.New("invalid IP address")
}
return s.server.ListenAndServe() var ln net.Listener
if strings.Count(ip, ":") >= 2 { //nolint:mnd // ipv6
if ln, err = net.Listen("tcp6", fmt.Sprintf("[%s]:%d", ip, port)); err != nil {
return err
}
} else { // ipv4
if ln, err = net.Listen("tcp4", fmt.Sprintf("%s:%d", ip, port)); err != nil {
return err
}
}
return s.server.Serve(ln)
} }
// Stop server gracefully. // Stop server gracefully.
@ -113,5 +135,5 @@ func (s *Server) Stop(timeout time.Duration) error {
var ctx, cancel = context.WithTimeout(context.Background(), timeout) var ctx, cancel = context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
return s.server.Shutdown(ctx) return s.server.ShutdownWithContext(ctx)
} }

View File

@ -1,7 +1,6 @@
package http_test package http_test
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -21,7 +20,7 @@ import (
// TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers. // TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers.
func TestRouting(t *testing.T) { func TestRouting(t *testing.T) {
var ( var (
srv = appHttp.NewServer(context.Background(), logger.NewNop()) srv = appHttp.NewServer(logger.NewNop())
cfg = config.New() cfg = config.New()
) )
@ -296,7 +295,7 @@ func TestRouting(t *testing.T) {
assert.Equal(t, http.StatusNotFound, status) assert.Equal(t, http.StatusNotFound, status)
assert.Empty(t, body) assert.Empty(t, body)
assert.Empty(t, headers.Get("Content-Type")) assert.Contains(t, headers.Get("Content-Type"), "text/plain")
} }
}) })