feat: fasthttp is back

This commit is contained in:
Paramtamtam 2024-07-01 23:41:26 +04:00
parent afaef54ddf
commit df2163e41e
No known key found for this signature in database
GPG Key ID: 366371698FAD0A2B
21 changed files with 432 additions and 315 deletions

4
go.mod
View File

@ -10,12 +10,16 @@ require (
)
require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.55.0 // indirect
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

8
go.sum
View File

@ -1,8 +1,12 @@
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@ -26,6 +30,10 @@ github.com/urfave/cli-docs/v3 v3.0.0-alpha5 h1:H1oWnR2/GN0dNm2PVylws+GxSOD6YOwW/
github.com/urfave/cli-docs/v3 v3.0.0-alpha5/go.mod h1:AIqom6Q60U4tiqHp41i7+/AB2XHgi1WvQ7jOFlccmZ4=
github.com/urfave/cli/v3 v3.0.0-alpha9 h1:P0RMy5fQm1AslQS+XCmy9UknDXctOmG/q/FZkUFnJSo=
github.com/urfave/cli/v3 v3.0.0-alpha9/go.mod h1:0kK/RUFHyh+yIKSfWxwheGndfnrvYSmYFVeKCh03ZUc=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8=

View File

@ -55,7 +55,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit
Aliases: []string{"perf", "test"},
Hidden: true,
Usage: "Simple performance (load) test for the HTTP server",
Action: func(ctx context.Context, c *cli.Command) error {
Action: func(ctx context.Context, c *cli.Command) error { // TODO: use fasthttp.Client
var (
perfCtx, cancel = context.WithTimeout(ctx, c.Duration(durationFlag.Name))
startedAt = time.Now()

View File

@ -290,7 +290,7 @@ func NewCommand(log *logger.Logger) *cli.Command { //nolint:funlen,gocognit,gocy
// Run current command.
func (cmd *command) Run(ctx context.Context, log *logger.Logger, cfg *config.Config) error { //nolint:funlen
var srv = appHttp.NewServer(ctx, log)
var srv = appHttp.NewServer(log)
if err := srv.Register(cfg); err != nil {
return err

View File

@ -1,10 +1,11 @@
package error_page
import (
"net/http"
"path/filepath"
"strconv"
"strings"
"github.com/valyala/fasthttp"
)
// extractCodeFromURL extracts the error code from the given URL.
@ -37,11 +38,15 @@ func extractCodeFromURL(url string) (uint16, bool) {
func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn
// extractCodeFromHeaders extracts the error code from the given headers.
func extractCodeFromHeaders(headers http.Header) (uint16, bool) {
func extractCodeFromHeaders(headers *fasthttp.RequestHeader) (uint16, bool) {
if headers == nil {
return 0, false
}
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
// HTTP status code returned by the request
if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 {
if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 {
if value := headers.Peek("X-Code"); len(value) > 0 && len(value) <= 3 {
if code, err := strconv.ParseUint(string(value), 10, 16); err == nil && code > 0 && code < 999 {
return uint16(code), true
}
}
@ -50,7 +55,7 @@ func extractCodeFromHeaders(headers http.Header) (uint16, bool) {
}
// HeadersContainCode checks if the given headers contain an error code.
func HeadersContainCode(headers http.Header) (ok bool) {
func HeadersContainCode(headers *fasthttp.RequestHeader) (ok bool) {
_, ok = extractCodeFromHeaders(headers)
return

View File

@ -1,10 +1,10 @@
package error_page_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
)
@ -36,18 +36,26 @@ func TestURLContainsCode(t *testing.T) {
func TestHeadersContainCode(t *testing.T) {
t.Parallel()
var mkHeaders = func(key, value string) *fasthttp.RequestHeader {
var out = new(fasthttp.RequestHeader)
out.Set(key, value)
return out
}
for name, _tt := range map[string]struct {
giveHeaders http.Header
giveHeaders *fasthttp.RequestHeader
wantOk bool
}{
"with code": {giveHeaders: http.Header{"X-Code": {"404"}}, wantOk: true},
"with code": {giveHeaders: mkHeaders("X-Code", "404"), wantOk: true},
"empty": {giveHeaders: nil},
"no code": {giveHeaders: http.Header{"X-Code": {""}}},
"wrong": {giveHeaders: http.Header{"X-Code": {"foo"}}},
"too big": {giveHeaders: http.Header{"X-Code": {"1000"}}},
"too small": {giveHeaders: http.Header{"X-Code": {"0"}}},
"negative": {giveHeaders: http.Header{"X-Code": {"-1"}}},
"no code": {giveHeaders: mkHeaders("X-Code", "")},
"wrong": {giveHeaders: mkHeaders("X-Code", "foo")},
"too big": {giveHeaders: mkHeaders("X-Code", "1000")},
"too small": {giveHeaders: mkHeaders("X-Code", "0")},
"negative": {giveHeaders: mkHeaders("X-Code", "-1")},
} {
tt := _tt

View File

@ -2,10 +2,11 @@ package error_page
import (
"math"
"net/http"
"slices"
"strconv"
"strings"
"github.com/valyala/fasthttp"
)
type preferredFormat = byte
@ -21,10 +22,10 @@ const (
// detectPreferredFormatForClient detects the preferred format for the client based on the headers.
// It supports the following headers: Content-Type, Accept, X-Format.
// If the headers are not set or the format is not recognized, it returns unknownFormat.
func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nolint:funlen,gocognit
func detectPreferredFormatForClient(headers *fasthttp.RequestHeader) preferredFormat { //nolint:funlen,gocognit
var contentType, accept string
if contentTypeHeader := strings.TrimSpace(headers.Get("Content-Type")); contentTypeHeader != "" { //nolint:nestif
if contentTypeHeader := strings.TrimSpace(string(headers.Peek("Content-Type"))); contentTypeHeader != "" { //nolint:nestif,lll
// https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type
// text/html; charset=utf-8
// multipart/form-data; boundary=something
@ -38,11 +39,11 @@ func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nol
// take the whole value
contentType = contentTypeHeader
}
} else if xFormatHeader := strings.TrimSpace(headers.Get("X-Format")); xFormatHeader != "" {
} else if xFormatHeader := strings.TrimSpace(string(headers.Peek("X-Format"))); xFormatHeader != "" {
// https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
// Value of the `Accept` header sent by the client
accept = xFormatHeader
} else if acceptHeader := strings.TrimSpace(headers.Get("Accept")); acceptHeader != "" {
} else if acceptHeader := strings.TrimSpace(string(headers.Peek("Accept"))); acceptHeader != "" {
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
// text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8
// text/html

View File

@ -1,80 +1,80 @@
package error_page
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
)
func Test_detectPreferredFormatForClient(t *testing.T) {
t.Parallel()
for name, _tt := range map[string]struct {
giveHeaders http.Header
giveHeaders map[string][]string
wantFormat preferredFormat
}{
"content type json": {
giveHeaders: http.Header{"Content-Type": {"application/jSoN"}},
giveHeaders: map[string][]string{"Content-Type": {"application/jSoN"}},
wantFormat: jsonFormat,
},
"content type xml": {
giveHeaders: http.Header{"Content-Type": {"application/xml; charset=UTF-8"}},
giveHeaders: map[string][]string{"Content-Type": {"application/xml; charset=UTF-8"}},
wantFormat: xmlFormat,
},
"content type html": {
giveHeaders: http.Header{"Content-Type": {"text/hTmL; charset=utf-8"}},
giveHeaders: map[string][]string{"Content-Type": {"text/hTmL; charset=utf-8"}},
wantFormat: htmlFormat,
},
"content type plain": {
giveHeaders: http.Header{"Content-Type": {"text/plaIN"}},
giveHeaders: map[string][]string{"Content-Type": {"text/plaIN"}},
wantFormat: plainTextFormat,
},
"accept json": {
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat,
},
"accept xml, depends on weight": {
giveHeaders: http.Header{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}},
wantFormat: xmlFormat,
},
"accept json, depends on weight": {
giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat,
},
"accept xml": {
giveHeaders: http.Header{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
wantFormat: xmlFormat,
},
"accept html": {
giveHeaders: http.Header{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}},
wantFormat: htmlFormat,
},
"accept plain": {
giveHeaders: http.Header{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}},
wantFormat: plainTextFormat,
},
"accept json, weighted values only": {
giveHeaders: http.Header{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}},
giveHeaders: map[string][]string{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}},
wantFormat: jsonFormat,
},
"x-format json, depends on weight": {
giveHeaders: http.Header{"X-Format": {"application/jsoN,*/*;q=0.8"}},
giveHeaders: map[string][]string{"X-Format": {"application/jsoN,*/*;q=0.8"}},
wantFormat: jsonFormat,
},
"x-format xml": {
giveHeaders: http.Header{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
giveHeaders: map[string][]string{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}},
wantFormat: xmlFormat,
},
"content type has priority over accept": {
giveHeaders: http.Header{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}},
giveHeaders: map[string][]string{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}},
wantFormat: plainTextFormat,
},
"accept has priority over x-format": {
giveHeaders: http.Header{"Accept": {"application/xml"}, "X-Format": {"text/plain"}},
giveHeaders: map[string][]string{"Accept": {"application/xml"}, "X-Format": {"text/plain"}},
wantFormat: plainTextFormat,
},
@ -82,25 +82,33 @@ func Test_detectPreferredFormatForClient(t *testing.T) {
giveHeaders: nil,
},
"empty content type": {
giveHeaders: http.Header{"Content-Type": {" "}},
giveHeaders: map[string][]string{"Content-Type": {" "}},
},
"wrong content type": {
giveHeaders: http.Header{"Content-Type": {"multipart/form-data; boundary=something"}},
giveHeaders: map[string][]string{"Content-Type": {"multipart/form-data; boundary=something"}},
},
"wrong accept": {
giveHeaders: http.Header{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}},
giveHeaders: map[string][]string{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}},
},
"none on invalid input": {
giveHeaders: http.Header{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}},
giveHeaders: map[string][]string{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}},
},
"completely unknown": {
giveHeaders: http.Header{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}},
giveHeaders: map[string][]string{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}},
},
} {
tt := _tt
t.Run(name, func(t *testing.T) {
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(tt.giveHeaders))
var headers = new(fasthttp.RequestHeader)
for key, values := range tt.giveHeaders {
for _, value := range values {
headers.Add(key, value)
}
}
assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(headers))
})
}
}

View File

@ -7,21 +7,24 @@ import (
"sync/atomic"
"time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/config"
"gh.tarampamp.am/error-pages/internal/logger"
"gh.tarampamp.am/error-pages/internal/template"
)
// New creates a new handler that returns an error page with the specified status code and format.
func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,gocognit,gocyclo
const contentTypeHeader = "Content-Type"
func New(cfg *config.Config, log *logger.Logger) fasthttp.RequestHandler { //nolint:funlen,gocognit,gocyclo
return func(ctx *fasthttp.RequestCtx) {
var (
reqHeaders = &ctx.Request.Header
code uint16
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var code uint16
if fromUrl, okUrl := extractCodeFromURL(r.URL.Path); okUrl {
if fromUrl, okUrl := extractCodeFromURL(string(ctx.RequestURI())); okUrl {
code = fromUrl
} else if fromHeader, okHeaders := extractCodeFromHeaders(r.Header); okHeaders {
} else if fromHeader, okHeaders := extractCodeFromHeaders(reqHeaders); okHeaders {
code = fromHeader
} else {
code = cfg.DefaultCodeToRender
@ -35,23 +38,23 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
httpCode = http.StatusOK
}
var format = detectPreferredFormatForClient(r.Header)
var format = detectPreferredFormatForClient(reqHeaders)
{ // deal with the headers
switch format {
case jsonFormat:
w.Header().Set(contentTypeHeader, "application/json; charset=utf-8")
ctx.SetContentType("application/json; charset=utf-8")
case xmlFormat:
w.Header().Set(contentTypeHeader, "application/xml; charset=utf-8")
ctx.SetContentType("application/xml; charset=utf-8")
case htmlFormat:
w.Header().Set(contentTypeHeader, "text/html; charset=utf-8")
ctx.SetContentType("text/html; charset=utf-8")
default:
w.Header().Set(contentTypeHeader, "text/plain; charset=utf-8") // plainTextFormat as default
ctx.SetContentType("text/plain; charset=utf-8") // plainTextFormat as default
}
// https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag
// disallow indexing of the error pages
w.Header().Set("X-Robots-Tag", "noindex")
ctx.Response.Header.Set("X-Robots-Tag", "noindex")
switch code {
case http.StatusRequestTimeout, http.StatusTooEarly, http.StatusTooManyRequests,
@ -59,18 +62,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
http.StatusGatewayTimeout:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
// tell the client (search crawler) to retry the request after 120 seconds
w.Header().Set("Retry-After", "120")
ctx.Response.Header.Set("Retry-After", "120")
}
// proxy the headers from the incoming request to the error page response if they are defined in the config
for _, proxyHeader := range cfg.ProxyHeaders {
if value := r.Header.Get(proxyHeader); value != "" {
w.Header().Set(proxyHeader, value)
if value := reqHeaders.Peek(proxyHeader); len(value) > 0 {
ctx.Response.Header.SetBytesV(proxyHeader, value)
}
}
}
w.WriteHeader(httpCode)
ctx.SetStatusCode(httpCode)
// prepare the template properties for rendering
var tplProps = template.Props{
@ -81,14 +84,14 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
//nolint:lll
if cfg.ShowDetails { // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/
tplProps.OriginalURI = r.Header.Get("X-Original-URI") // (ingress-nginx) URI that caused the error
tplProps.Namespace = r.Header.Get("X-Namespace") // (ingress-nginx) namespace where the backend Service is located
tplProps.IngressName = r.Header.Get("X-Ingress-Name") // (ingress-nginx) name of the Ingress where the backend is defined
tplProps.ServiceName = r.Header.Get("X-Service-Name") // (ingress-nginx) name of the Service backing the backend
tplProps.ServicePort = r.Header.Get("X-Service-Port") // (ingress-nginx) port number of the Service backing the backend
tplProps.RequestID = r.Header.Get("X-Request-Id") // (ingress-nginx) unique ID that identifies the request - same as for backend service
tplProps.ForwardedFor = r.Header.Get("X-Forwarded-For") // the value of the `X-Forwarded-For` header
tplProps.Host = r.Host // the value of the `Host` header
tplProps.OriginalURI = string(reqHeaders.Peek("X-Original-URI")) // (ingress-nginx) URI that caused the error
tplProps.Namespace = string(reqHeaders.Peek("X-Namespace")) // (ingress-nginx) namespace where the backend Service is located
tplProps.IngressName = string(reqHeaders.Peek("X-Ingress-Name")) // (ingress-nginx) name of the Ingress where the backend is defined
tplProps.ServiceName = string(reqHeaders.Peek("X-Service-Name")) // (ingress-nginx) name of the Service backing the backend
tplProps.ServicePort = string(reqHeaders.Peek("X-Service-Port")) // (ingress-nginx) port number of the Service backing the backend
tplProps.RequestID = string(reqHeaders.Peek("X-Request-Id")) // (ingress-nginx) unique ID that identifies the request - same as for backend service
tplProps.ForwardedFor = string(reqHeaders.Peek("X-Forwarded-For")) // the value of the `X-Forwarded-For` header
tplProps.Host = string(reqHeaders.Peek("Host")) // the value of the `Host` header
}
// try to find the code message and description in the config and if not - use the standard status text or fallback
@ -105,18 +108,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
case format == jsonFormat && cfg.Formats.JSON != "":
if content, err := template.Render(cfg.Formats.JSON, tplProps); err != nil {
j, _ := json.Marshal(fmt.Sprintf("Failed to render the JSON template: %s", err.Error()))
write(w, log, j)
write(ctx, log, j)
} else {
write(w, log, content)
write(ctx, log, content)
}
case format == xmlFormat && cfg.Formats.XML != "":
if content, err := template.Render(cfg.Formats.XML, tplProps); err != nil {
write(w, log, fmt.Sprintf(
write(ctx, log, fmt.Sprintf(
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<error>Failed to render the XML template: %s</error>", err.Error(),
))
} else {
write(w, log, content)
write(ctx, log, content)
}
case format == htmlFormat:
@ -125,16 +128,16 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
if tpl, found := cfg.Templates.Get(templateName); found {
if content, err := template.Render(tpl, tplProps); err != nil {
// TODO: add GZIP compression for the HTML content support
write(w, log, fmt.Sprintf(
write(ctx, log, fmt.Sprintf(
"<!DOCTYPE html>\n<html><body>Failed to render the HTML template %s: %s</body></html>",
templateName,
err.Error(),
))
} else {
write(w, log, content)
write(ctx, log, content)
}
} else {
write(w, log, fmt.Sprintf(
write(ctx, log, fmt.Sprintf(
"<!DOCTYPE html>\n<html><body>Template %s not found and cannot be used</body></html>", templateName,
))
}
@ -142,18 +145,18 @@ func New(cfg *config.Config, log *logger.Logger) http.Handler { //nolint:funlen,
default: // plainTextFormat as default
if cfg.Formats.PlainText != "" {
if content, err := template.Render(cfg.Formats.PlainText, tplProps); err != nil {
write(w, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error()))
write(ctx, log, fmt.Sprintf("Failed to render the PlainText template: %s", err.Error()))
} else {
write(w, log, content)
write(ctx, log, content)
}
} else {
write(w, log, `The requested content format is not supported.
write(ctx, log, `The requested content format is not supported.
Please create an issue on the project's GitHub page to request support for this format.
Supported formats: JSON, XML, HTML, Plain Text`)
}
}
})
}
}
var (
@ -204,7 +207,7 @@ func templateToUse(cfg *config.Config) string {
}
// write the content to the response writer and log the error if any.
func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content T) {
func write[T string | []byte](ctx *fasthttp.RequestCtx, log *logger.Logger, content T) {
var data []byte
if s, ok := any(content).(string); ok {
@ -213,7 +216,7 @@ func write[T string | []byte](w http.ResponseWriter, log *logger.Logger, content
data = any(content).([]byte)
}
if _, err := w.Write(data); err != nil && log != nil {
if _, err := ctx.Write(data); err != nil && log != nil {
log.Error("failed to write the response body",
logger.String("content", string(data)),
logger.Error(err),

View File

@ -2,13 +2,14 @@ package error_page_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gh.tarampamp.am/error-pages/internal/config"
"gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
"gh.tarampamp.am/error-pages/internal/http/httptest"
"gh.tarampamp.am/error-pages/internal/logger"
)
@ -26,7 +27,7 @@ func TestHandler(t *testing.T) {
}{
"common, plain text": {
giveConfig: func() *config.Config { cfg := config.New(); return &cfg },
giveUrl: "/",
giveUrl: "http://testing/",
giveHeaders: map[string]string{"Content-Type": "text/plain"},
wantStatusCode: http.StatusOK,
@ -41,7 +42,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/",
giveUrl: "http://testing/",
giveHeaders: map[string]string{"X-Format": "text/html", "X-Code": "407"},
wantStatusCode: http.StatusOK,
@ -60,7 +61,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/503.html",
giveUrl: "http://testing/503.html",
giveHeaders: map[string]string{"Accept": "application/json", "X-FooBar": "baz"},
wantStatusCode: http.StatusServiceUnavailable,
@ -78,7 +79,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/500",
giveUrl: "http://testing/500",
giveHeaders: map[string]string{"Accept": "application/xml", "X-FooBar": "baz"},
wantStatusCode: http.StatusOK,
@ -96,7 +97,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/503",
giveUrl: "http://example.com/503",
giveHeaders: map[string]string{
"Accept": "application/json",
"X-Original-URI": "/foo/bar",
@ -106,7 +107,6 @@ func TestHandler(t *testing.T) {
"X-Service-Port": "666",
"X-Request-ID": "req-id-777",
"X-Forwarded-For": "123.123.123.123:12312",
"Host": "example.com",
},
wantStatusCode: http.StatusOK,
@ -133,7 +133,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/100",
giveUrl: "http://testing/100",
giveHeaders: map[string]string{"Accept": "application/json"},
wantStatusCode: http.StatusOK,
@ -148,7 +148,7 @@ func TestHandler(t *testing.T) {
return &cfg
},
giveUrl: "/1",
giveUrl: "http://testing/1",
giveHeaders: map[string]string{"Accept": "application/json"},
wantStatusCode: http.StatusOK,
@ -159,31 +159,30 @@ func TestHandler(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
var (
req = httptest.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
handler = error_page.New(tt.giveConfig(), logger.NewNop())
rr = httptest.NewRecorder()
)
var handler = error_page.New(tt.giveConfig(), logger.NewNop())
req, reqErr := http.NewRequest(http.MethodGet, tt.giveUrl, http.NoBody)
require.NoError(t, reqErr)
for k, v := range tt.giveHeaders {
req.Header.Set(k, v)
}
handler.ServeHTTP(rr, req)
httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
assert.Equal(t, tt.wantStatusCode, status)
assert.Equal(t, rr.Code, tt.wantStatusCode)
for hName, hWant := range tt.wantHeaders {
for hGot := range rr.Header() {
if hGot == hName {
assert.Contains(t, hWant, rr.Header().Get(hGot))
for hName, hWant := range tt.wantHeaders {
for hGot := range headers {
if hGot == hName {
assert.Contains(t, hWant, headers.Get(hGot))
}
}
}
}
for _, wantBodyInclude := range tt.wantBodyIncludes {
assert.Contains(t, rr.Body.String(), wantBodyInclude)
}
for _, wantBodyInclude := range tt.wantBodyIncludes {
assert.Contains(t, body, wantBodyInclude)
}
})
})
}
}
@ -207,19 +206,17 @@ func TestRotationModeOnEachRequest(t *testing.T) {
)
for range 300 {
var (
req = httptest.NewRequest(http.MethodGet, "/", http.NoBody)
rr = httptest.NewRecorder()
)
req, reqErr := http.NewRequest(http.MethodGet, "http://testing/", http.NoBody)
require.NoError(t, reqErr)
req.Header.Set("Accept", "text/html")
handler.ServeHTTP(rr, req)
if lastResponseBody != rr.Body.String() {
changedTimes++
lastResponseBody = rr.Body.String()
}
httptest.HandleFastRequest(t, handler, req, func(status int, body string, headers http.Header) {
if lastResponseBody != body {
changedTimes++
lastResponseBody = body
}
})
}
assert.True(t, changedTimes > 30, "the template should be changed at least 30 times")

View File

@ -2,24 +2,29 @@ package live
import (
"net/http"
"github.com/valyala/fasthttp"
)
// New creates a new handler that returns "OK" for GET and HEAD requests.
func New() http.Handler {
var body = []byte("OK\n")
func New() fasthttp.RequestHandler {
var (
body = []byte("OK\n")
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
return func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Method()) {
case fasthttp.MethodGet:
ctx.SetContentType("text/plain; charset=utf-8")
ctx.SetStatusCode(http.StatusOK)
_, _ = ctx.Write(body)
case http.MethodHead:
w.WriteHeader(http.StatusOK)
case fasthttp.MethodHead:
ctx.SetStatusCode(http.StatusOK)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
}
})
}
}

View File

@ -2,43 +2,37 @@ package live_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/live"
"gh.tarampamp.am/error-pages/internal/http/httptest"
)
func TestServeHTTP(t *testing.T) {
t.Parallel()
var handler = live.New()
var (
handler = live.New()
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, "OK\n", rr.Body.String())
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
assert.Equal(t, "OK\n", body)
})
})
t.Run("head", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Empty(t, headers.Get("Content-Type"))
assert.Empty(t, body)
})
})
t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost,
http.MethodPut,
} {
var (
req = httptest.NewRequest(method, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
assert.Equal(t, "Method Not Allowed\n", body)
})
}
})
}

View File

@ -3,25 +3,29 @@ package static
import (
_ "embed"
"net/http"
"github.com/valyala/fasthttp"
)
//go:embed favicon.ico
var Favicon []byte
// New creates a new handler that returns the provided content for GET and HEAD requests.
func New(content []byte) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", http.DetectContentType(content))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(content)
func New(content []byte) fasthttp.RequestHandler {
var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
case http.MethodHead:
w.WriteHeader(http.StatusOK)
return func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Method()) {
case fasthttp.MethodGet:
ctx.SetContentType(http.DetectContentType(content))
ctx.SetStatusCode(http.StatusOK)
_, _ = ctx.Write(content)
case fasthttp.MethodHead:
ctx.SetStatusCode(http.StatusOK)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
}
})
}
}

View File

@ -2,43 +2,37 @@ package static_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/static"
"gh.tarampamp.am/error-pages/internal/http/httptest"
)
func TestServeHTTP(t *testing.T) {
t.Parallel()
var handler = static.New([]byte{1, 2, 3})
var (
handler = static.New([]byte{1, 2, 3})
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "application/octet-stream")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.Bytes(), []byte{1, 2, 3})
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, "application/octet-stream", headers.Get("Content-Type"))
assert.Equal(t, []byte{1, 2, 3}, []byte(body))
})
})
t.Run("head", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Empty(t, headers.Get("Content-Type"))
assert.Empty(t, body)
})
})
t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost,
http.MethodPut,
} {
var (
req = httptest.NewRequest(method, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
assert.Equal(t, "Method Not Allowed\n", body)
})
}
})
}
@ -65,16 +54,15 @@ func TestServeHTTP(t *testing.T) {
func TestServeHTTP_Favicon(t *testing.T) {
t.Parallel()
var (
handler = static.New(static.Favicon)
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
httptest.HandleFast(t,
static.New(static.Favicon),
http.MethodGet,
"http://testing",
http.NoBody,
func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, "image/x-icon", headers.Get("Content-Type"))
assert.Equal(t, static.Favicon, []byte(body))
},
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "image/x-icon")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.Bytes(), static.Favicon)
}

View File

@ -4,28 +4,32 @@ import (
"encoding/json"
"net/http"
"strings"
"github.com/valyala/fasthttp"
)
// New creates a handler that returns the version of the service in JSON format.
func New(ver string) http.Handler {
func New(ver string) fasthttp.RequestHandler {
var body, _ = json.Marshal(struct { //nolint:errchkjson
Version string `json:"version"`
}{
Version: strings.TrimSpace(ver),
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
var notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
case http.MethodHead:
w.WriteHeader(http.StatusOK)
return func(ctx *fasthttp.RequestCtx) {
switch string(ctx.Method()) {
case fasthttp.MethodGet:
ctx.SetContentType("application/json; charset=utf-8")
ctx.SetStatusCode(http.StatusOK)
_, _ = ctx.Write(body)
case fasthttp.MethodHead:
ctx.SetStatusCode(http.StatusOK)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
ctx.Error(notAllowed, http.StatusMethodNotAllowed)
}
})
}
}

View File

@ -2,43 +2,37 @@ package version_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"gh.tarampamp.am/error-pages/internal/http/handlers/version"
"gh.tarampamp.am/error-pages/internal/http/httptest"
)
func TestServeHTTP(t *testing.T) {
t.Parallel()
var handler = version.New("\t\n foo@bar ")
var (
handler = version.New("\t\n foo@bar ")
url = "http://testing"
body = http.NoBody
)
t.Run("get", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodGet, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "application/json; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusOK)
assert.Equal(t, rr.Body.String(), `{"version":"foo@bar"}`)
httptest.HandleFast(t, handler, http.MethodGet, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, "application/json; charset=utf-8", headers.Get("Content-Type"))
assert.Equal(t, `{"version":"foo@bar"}`, body)
})
})
t.Run("head", func(t *testing.T) {
var (
req = httptest.NewRequest(http.MethodHead, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Code, http.StatusOK)
assert.Empty(t, rr.Header().Get("Content-Type"))
assert.Empty(t, rr.Body.Bytes())
httptest.HandleFast(t, handler, http.MethodHead, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusOK, status)
assert.Empty(t, headers.Get("Content-Type"))
assert.Empty(t, body)
})
})
t.Run("method not allowed", func(t *testing.T) {
@ -48,16 +42,11 @@ func TestServeHTTP(t *testing.T) {
http.MethodPost,
http.MethodPut,
} {
var (
req = httptest.NewRequest(method, "http://testing", http.NoBody)
rr = httptest.NewRecorder()
)
handler.ServeHTTP(rr, req)
assert.Equal(t, rr.Header().Get("Content-Type"), "text/plain; charset=utf-8")
assert.Equal(t, rr.Code, http.StatusMethodNotAllowed)
assert.Equal(t, "Method Not Allowed\n", rr.Body.String())
httptest.HandleFast(t, handler, method, url, body, func(status int, body string, headers http.Header) {
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.Equal(t, "text/plain; charset=utf-8", headers.Get("Content-Type"))
assert.Equal(t, "Method Not Allowed\n", body)
})
}
})
}

View File

@ -0,0 +1,69 @@
// Package httptest provides utilities for (fast-)HTTP testing.
package httptest
import (
"context"
"io"
"net"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
// HandleFastRequest serves http request using provided fasthttp handler and HTTP request.
func HandleFastRequest(
t *testing.T,
handler fasthttp.RequestHandler,
req *http.Request,
check func(status int, body string, _ http.Header),
) {
t.Helper()
// create in-memory listener
var ln = fasthttputil.NewInmemoryListener()
defer func() { require.NoError(t, ln.Close()) }()
// start fasthttp server
go func() { require.NoError(t, fasthttp.Serve(ln, handler)) }()
// send http request
resp, respErr := (&http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return ln.Dial() },
},
}).Do(req)
require.NoError(t, respErr)
// close response body after the test
defer func() { assert.NoError(t, resp.Body.Close()) }()
// read response body
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// check the response
check(resp.StatusCode, string(respBody), resp.Header)
}
// HandleFast serves http request using provided fasthttp handler.
func HandleFast(
t *testing.T,
handler fasthttp.RequestHandler,
method string,
url string,
body io.Reader,
check func(status int, body string, _ http.Header),
) {
t.Helper()
// create http request
req, reqErr := http.NewRequest(method, url, body)
require.NoError(t, reqErr)
// serve http request
HandleFastRequest(t, handler, req, check)
}

View File

@ -1,20 +1,24 @@
package logreq
import (
"net/http"
"time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/logger"
)
// New creates a middleware for [http.ServeMux] that logs every incoming request.
// New creates a middleware that logs every incoming request.
//
// The skipper function should return true if the request should be skipped. It's ok to pass nil.
func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if skipper != nil && skipper(r) {
next.ServeHTTP(w, r)
func New(
log *logger.Logger,
skipper func(*fasthttp.RequestCtx) bool,
) func(fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
if skipper != nil && skipper(ctx) {
next(ctx)
return
}
@ -23,27 +27,35 @@ func New(log *logger.Logger, skipper func(*http.Request) bool) func(http.Handler
defer func() {
var fields = []logger.Attr{
logger.String("useragent", r.UserAgent()),
logger.String("method", r.Method),
logger.String("url", r.URL.String()),
logger.String("referer", r.Referer()),
logger.String("content type", w.Header().Get("Content-Type")),
logger.String("remote addr", r.RemoteAddr),
logger.String("method", r.Method),
logger.Int("status code", ctx.Response.StatusCode()),
logger.String("useragent", string(ctx.UserAgent())),
logger.String("method", string(ctx.Method())),
logger.String("url", string(ctx.RequestURI())),
logger.String("referer", string(ctx.Referer())),
logger.String("content type", string(ctx.Response.Header.ContentType())),
logger.String("remote addr", ctx.RemoteAddr().String()),
logger.Duration("duration", time.Since(now).Round(time.Microsecond)),
}
if log.Level() <= logger.DebugLevel {
var (
reqHeaders = make(map[string]string)
respHeaders = make(map[string]string)
)
ctx.Request.Header.VisitAll(func(key, value []byte) { reqHeaders[string(key)] = string(value) })
ctx.Response.Header.VisitAll(func(key, value []byte) { respHeaders[string(key)] = string(value) })
fields = append(fields,
logger.Any("request headers", r.Header.Clone()),
logger.Any("response headers", w.Header().Clone()),
logger.Any("request headers", reqHeaders),
logger.Any("response headers", respHeaders),
)
}
log.Info("HTTP request processed", fields...)
}()
next.ServeHTTP(w, r)
})
next(ctx)
}
}
}

View File

@ -3,11 +3,12 @@ package logreq_test
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/http/httptest"
"gh.tarampamp.am/error-pages/internal/http/middleware/logreq"
"gh.tarampamp.am/error-pages/internal/logger"
)
@ -19,18 +20,19 @@ func TestNew(t *testing.T) {
buf bytes.Buffer
log, _ = logger.New(logger.DebugLevel, logger.JSONFormat, &buf)
mw = logreq.New(log, nil)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/foo/bar", http.NoBody)
mw = logreq.New(log, nil)
req, _ = http.NewRequest(http.MethodPut, "http://testing/foo/bar", http.NoBody)
)
req.Header.Set("User-Agent", "test")
req.Header.Set("Referer", "https://example.com")
req.Header.Set("Content-Type", "application/json")
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).ServeHTTP(rr, req)
httptest.HandleFastRequest(t,
mw(func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(http.StatusOK) }),
req,
func(status int, body string, _ http.Header) { assert.Equal(t, http.StatusOK, status) },
)
var logRecord = buf.String()

View File

@ -2,12 +2,15 @@ package http
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/valyala/fasthttp"
"gh.tarampamp.am/error-pages/internal/appmeta"
"gh.tarampamp.am/error-pages/internal/config"
ep "gh.tarampamp.am/error-pages/internal/http/handlers/error_page"
@ -21,11 +24,11 @@ import (
// Server is an HTTP server for serving error pages.
type Server struct {
log *logger.Logger
server *http.Server
server *fasthttp.Server
}
// NewServer creates a new HTTP server.
func NewServer(baseCtx context.Context, log *logger.Logger) Server {
func NewServer(log *logger.Logger) Server {
const (
readTimeout = 30 * time.Second
writeTimeout = readTimeout + 10*time.Second // should be bigger than the read timeout
@ -34,13 +37,14 @@ func NewServer(baseCtx context.Context, log *logger.Logger) Server {
return Server{
log: log,
server: &http.Server{
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
ReadHeaderTimeout: readTimeout,
MaxHeaderBytes: maxHeaderBytes,
ErrorLog: logger.NewStdLog(log),
BaseContext: func(net.Listener) context.Context { return baseCtx },
server: &fasthttp.Server{
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
ReadBufferSize: maxHeaderBytes,
DisablePreParseMultipartForm: true,
NoDefaultServerHeader: true,
CloseOnShutdown: true,
Logger: logger.NewStdLog(log),
},
}
}
@ -52,60 +56,78 @@ func (s *Server) Register(cfg *config.Config) error {
versionHandler = version.New(appmeta.Version())
errorPagesHandler = ep.New(cfg, s.log)
faviconHandler = static.New(static.Favicon)
notFound = http.StatusText(http.StatusNotFound) + "\n"
notAllowed = http.StatusText(http.StatusMethodNotAllowed) + "\n"
)
s.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var url, method = r.URL.Path, r.Method
s.server.Handler = func(ctx *fasthttp.RequestCtx) {
var url, method = string(ctx.RequestURI()), string(ctx.Method())
switch {
// live endpoints
case url == "/health/live" || url == "/health" || url == "/healthz" || url == "/live":
liveHandler.ServeHTTP(w, r)
liveHandler(ctx)
// version endpoint
case url == "/version":
versionHandler.ServeHTTP(w, r)
versionHandler(ctx)
// favicon.ico endpoint
case url == "/favicon.ico":
faviconHandler.ServeHTTP(w, r)
faviconHandler(ctx)
// error pages endpoints:
// - /
// - /{code}.html
// - /{code}.htm
// - /{code}
case method == http.MethodGet && (url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(r.Header)):
errorPagesHandler.ServeHTTP(w, r)
case method == fasthttp.MethodGet &&
(url == "/" || ep.URLContainsCode(url) || ep.HeadersContainCode(&ctx.Request.Header)):
errorPagesHandler(ctx)
// wrong requests handling
default:
switch {
case method == http.MethodHead:
w.WriteHeader(http.StatusNotFound)
case method == http.MethodGet:
http.NotFound(w, r)
case method == fasthttp.MethodHead:
ctx.Error(notAllowed, fasthttp.StatusNotFound)
case method == fasthttp.MethodGet:
ctx.Error(notFound, fasthttp.StatusNotFound)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
ctx.Error(notAllowed, fasthttp.StatusMethodNotAllowed)
}
}
})
}
// apply middleware
s.server.Handler = logreq.New(s.log, func(r *http.Request) bool {
s.server.Handler = logreq.New(s.log, func(ctx *fasthttp.RequestCtx) bool {
// skip logging healthcheck and .ico (favicon) requests
return strings.Contains(strings.ToLower(r.UserAgent()), "healthcheck") ||
strings.HasSuffix(r.URL.Path, ".ico")
return strings.Contains(strings.ToLower(string(ctx.UserAgent())), "healthcheck") ||
strings.HasSuffix(string(ctx.Path()), ".ico")
})(s.server.Handler)
return nil
}
// Start server.
func (s *Server) Start(ip string, port uint16) error {
s.server.Addr = ip + ":" + strconv.Itoa(int(port))
func (s *Server) Start(ip string, port uint16) (err error) {
if net.ParseIP(ip) == nil {
return errors.New("invalid IP address")
}
return s.server.ListenAndServe()
var ln net.Listener
if strings.Count(ip, ":") >= 2 { //nolint:mnd // ipv6
if ln, err = net.Listen("tcp6", fmt.Sprintf("[%s]:%d", ip, port)); err != nil {
return err
}
} else { // ipv4
if ln, err = net.Listen("tcp4", fmt.Sprintf("%s:%d", ip, port)); err != nil {
return err
}
}
return s.server.Serve(ln)
}
// Stop server gracefully.
@ -113,5 +135,5 @@ func (s *Server) Stop(timeout time.Duration) error {
var ctx, cancel = context.WithTimeout(context.Background(), timeout)
defer cancel()
return s.server.Shutdown(ctx)
return s.server.ShutdownWithContext(ctx)
}

View File

@ -1,7 +1,6 @@
package http_test
import (
"context"
"errors"
"fmt"
"io"
@ -21,7 +20,7 @@ import (
// TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers.
func TestRouting(t *testing.T) {
var (
srv = appHttp.NewServer(context.Background(), logger.NewNop())
srv = appHttp.NewServer(logger.NewNop())
cfg = config.New()
)
@ -296,7 +295,7 @@ func TestRouting(t *testing.T) {
assert.Equal(t, http.StatusNotFound, status)
assert.Empty(t, body)
assert.Empty(t, headers.Get("Content-Type"))
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
})