diff --git a/.golangci.yml b/.golangci.yml index f984a57..50d0d53 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -37,9 +37,9 @@ linters-settings: ignore-words: [cancelled] lll: line-length: 120 -# forbidigo: -# forbid: -# - '^(fmt\.Print(|f|ln)|print(|ln))(# it looks like a forgotten debugging printing call)?$' + forbidigo: + forbid: + - '^(fmt\.Print(|f|ln)|print(|ln))(# it looks like a forgotten debugging printing call)?$' prealloc: simple: true range-loops: true diff --git a/internal/http/handlers/error_page/code.go b/internal/http/handlers/error_page/code.go new file mode 100644 index 0000000..8dfe7ec --- /dev/null +++ b/internal/http/handlers/error_page/code.go @@ -0,0 +1,57 @@ +package error_page + +import ( + "net/http" + "path/filepath" + "strconv" + "strings" +) + +// extractCodeFromURL extracts the error code from the given URL. +func extractCodeFromURL(url string) (uint16, bool) { + var parts = strings.SplitN(strings.TrimLeft(url, "/"), "/", 1) + + if len(parts) == 0 { + return 0, false + } + + var ( + fileName = strings.ToLower(parts[0]) + ext = filepath.Ext(fileName) // ".html", ".htm", ".%something%" or an empty string + ) + + if ext != "" && ext != ".html" && ext != ".htm" { + return 0, false + } else if ext != "" { + fileName = strings.TrimSuffix(fileName, ext) + } + + if code, err := strconv.ParseUint(fileName, 10, 16); err == nil && code > 0 && code < 999 { + return uint16(code), true + } + + return 0, false +} + +// URLContainsCode checks if the given URL contains an error code. +func URLContainsCode(url string) (ok bool) { _, ok = extractCodeFromURL(url); return } //nolint:nlreturn + +// extractCodeFromHeaders extracts the error code from the given headers. +func extractCodeFromHeaders(headers http.Header) (uint16, bool) { + // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ + // HTTP status code returned by the request + if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 { + if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 { + return uint16(code), true + } + } + + return 0, false +} + +// HeadersContainCode checks if the given headers contain an error code. +func HeadersContainCode(headers http.Header) (ok bool) { + _, ok = extractCodeFromHeaders(headers) + + return +} diff --git a/internal/http/handlers/error_page/code_test.go b/internal/http/handlers/error_page/code_test.go new file mode 100644 index 0000000..0413d3c --- /dev/null +++ b/internal/http/handlers/error_page/code_test.go @@ -0,0 +1,58 @@ +package error_page_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "gh.tarampamp.am/error-pages/internal/http/handlers/error_page" +) + +func TestURLContainsCode(t *testing.T) { + t.Parallel() + + for giveUrl, wantOk := range map[string]bool{ + "/404": true, + "/404.htm": true, + "/404.HTM": true, + "/404.html": true, + "/404.HtmL": true, + "/404.css": false, + "/foo/404": false, + "/foo/404.html": false, + "/error": false, + "/": false, + "/////": false, + "///404//": false, + "": false, + } { + t.Run(giveUrl, func(t *testing.T) { + assert.Equal(t, wantOk, error_page.URLContainsCode(giveUrl)) + }) + } +} + +func TestHeadersContainCode(t *testing.T) { + t.Parallel() + + for name, _tt := range map[string]struct { + giveHeaders http.Header + wantOk bool + }{ + "with code": {giveHeaders: http.Header{"X-Code": {"404"}}, wantOk: true}, + + "empty": {giveHeaders: nil}, + "no code": {giveHeaders: http.Header{"X-Code": {""}}}, + "wrong": {giveHeaders: http.Header{"X-Code": {"foo"}}}, + "too big": {giveHeaders: http.Header{"X-Code": {"1000"}}}, + "too small": {giveHeaders: http.Header{"X-Code": {"0"}}}, + "negative": {giveHeaders: http.Header{"X-Code": {"-1"}}}, + } { + tt := _tt + + t.Run(name, func(t *testing.T) { + assert.Equal(t, tt.wantOk, error_page.HeadersContainCode(tt.giveHeaders)) + }) + } +} diff --git a/internal/http/handlers/error_page/format.go b/internal/http/handlers/error_page/format.go new file mode 100644 index 0000000..909f90b --- /dev/null +++ b/internal/http/handlers/error_page/format.go @@ -0,0 +1,134 @@ +package error_page + +import ( + "math" + "net/http" + "slices" + "strconv" + "strings" +) + +type preferredFormat = byte + +const ( + unknownFormat preferredFormat = iota // should be first, no format detected + jsonFormat // json + xmlFormat // xml + htmlFormat // html + plainTextFormat // plain text +) + +// detectPreferredFormatForClient detects the preferred format for the client based on the headers. +// It supports the following headers: Content-Type, Accept, X-Format. +// If the headers are not set or the format is not recognized, it returns unknownFormat. +func detectPreferredFormatForClient(headers http.Header) preferredFormat { //nolint:funlen,gocognit + var contentType, accept string + + if contentTypeHeader := strings.TrimSpace(headers.Get("Content-Type")); contentTypeHeader != "" { //nolint:nestif + // https://developer.mozilla.org/docs/Web/HTTP/Headers/Content-Type + // text/html; charset=utf-8 + // multipart/form-data; boundary=something + // application/json + if parts := strings.SplitN(contentTypeHeader, ";", 2); len(parts) > 1 { //nolint:mnd + // take only the first part of the content type: + // text/html; charset=utf-8 + // ^^^^^^^^^ - will be taken + contentType = strings.TrimSpace(parts[0]) + } else { + // take the whole value + contentType = contentTypeHeader + } + } else if xFormatHeader := strings.TrimSpace(headers.Get("X-Format")); xFormatHeader != "" { + // https://kubernetes.github.io/ingress-nginx/user-guide/custom-errors/ + // Value of the `Accept` header sent by the client + accept = xFormatHeader + } else if acceptHeader := strings.TrimSpace(headers.Get("Accept")); acceptHeader != "" { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept + // text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8 + // text/html + // image/* + // */* + // text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 + accept = acceptHeader + } else { + return unknownFormat + } + + switch { + case contentType != "": + return mimeTypeToPreferredFormat(contentType) + + case accept != "": + type piece struct { + mimeType string + weight int // to avoid float32 comparison (weight 1.0 = 1_0, 0.9 = 0_9, 0.8 = 0_8, etc.) + } + + var pieces = make([]piece, 0, strings.Count(accept, ",")+1) + + // split application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 into parts: + // ^^^^^^^^^ - segment #3 + // ^^^^^^^^^^^^^^^^^^^^^ - segment #2 + // ^^^^^^^^^^^^^^^^^^^^^ - segment #1 + for _, segment := range strings.FieldsFunc(accept, func(r rune) bool { return r == ',' }) { + // split segment into parts: + // + // application/xhtml+xml + // ^^^^^^^^^^^^^^^^^^^^^ - part #1 + // + // application/xml;q=0.9 + // ^^^^^ - part #2 + // ^^^^^^^^^^^^^^^ - part #1 + // + // */*;q=0.8 + // ^^^^^ - part #2 + // ^^^ - part #1 + if parts := strings.SplitN(strings.TrimSpace(segment), ";", 2); len(parts) > 0 { //nolint:mnd,nestif + if parts[0] == "*/*" { + continue // skip the wildcard + } + + var p = piece{mimeType: parts[0], weight: 1_0} //nolint:mnd // by default the weight is 10 (1.0 in float) + + if len(parts) > 1 { // we need to extract the weight + // trim the `q=` prefix and try to parse the weight value + if weight, err := strconv.ParseFloat(strings.TrimPrefix(strings.ToLower(parts[1]), "q="), 32); err == nil { + if weight = math.Round(weight*100) / 100; weight <= 1 && weight >= 0 { //nolint:mnd + p.weight = int(weight * 10) //nolint:mnd + } else { + p.weight = 0 // invalid weight, set it to 0 + } + } + } + + pieces = append(pieces, p) + } + } + + if len(pieces) > 0 { + slices.SortStableFunc(pieces, func(a, b piece) int { return b.weight - a.weight }) + + return mimeTypeToPreferredFormat(pieces[0].mimeType) + } + } + + return unknownFormat +} + +// mimeTypeToPreferredFormat converts a MIME type to a preferred format, using non-string comparison. +func mimeTypeToPreferredFormat(mimeType string) preferredFormat { + switch value := strings.ToLower(mimeType); { + case strings.Contains(value, "/json"): // application/json text/json + return jsonFormat + case strings.Contains(value, "/xml"): // application/xml text/xml + return xmlFormat + case strings.Contains(value, "+xml"): // application/xhtml+xml + return xmlFormat + case strings.Contains(value, "/html"): // text/html + return htmlFormat + case strings.Contains(value, "/plain"): // text/plain + return plainTextFormat + } + + return unknownFormat +} diff --git a/internal/http/handlers/error_page/format_test.go b/internal/http/handlers/error_page/format_test.go new file mode 100644 index 0000000..06fa20d --- /dev/null +++ b/internal/http/handlers/error_page/format_test.go @@ -0,0 +1,106 @@ +package error_page + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_detectPreferredFormatForClient(t *testing.T) { + t.Parallel() + + for name, _tt := range map[string]struct { + giveHeaders http.Header + wantFormat preferredFormat + }{ + "content type json": { + giveHeaders: http.Header{"Content-Type": {"application/jSoN"}}, + wantFormat: jsonFormat, + }, + "content type xml": { + giveHeaders: http.Header{"Content-Type": {"application/xml; charset=UTF-8"}}, + wantFormat: xmlFormat, + }, + "content type html": { + giveHeaders: http.Header{"Content-Type": {"text/hTmL; charset=utf-8"}}, + wantFormat: htmlFormat, + }, + "content type plain": { + giveHeaders: http.Header{"Content-Type": {"text/plaIN"}}, + wantFormat: plainTextFormat, + }, + + "accept json": { + giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, + wantFormat: jsonFormat, + }, + "accept xml, depends on weight": { + giveHeaders: http.Header{"Accept": {"text/html;q=0.5,application/xhtml+xml;q=0.9,application/xml;q=1,*/*;q=0.8"}}, + wantFormat: xmlFormat, + }, + "accept json, depends on weight": { + giveHeaders: http.Header{"Accept": {"application/jsoN,*/*;q=0.8"}}, + wantFormat: jsonFormat, + }, + "accept xml": { + giveHeaders: http.Header{"Accept": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, + wantFormat: xmlFormat, + }, + "accept html": { + giveHeaders: http.Header{"Accept": {"text/html, application/xhtml+xml, application/xml;q=0.9, image/avif, image/webp, */*;q=0.8"}}, + wantFormat: htmlFormat, + }, + "accept plain": { + giveHeaders: http.Header{"Accept": {"text/plaiN,text/html,application/xml;q=0.9,,,*/*;q=0.8"}}, + wantFormat: plainTextFormat, + }, + "accept json, weighted values only": { + giveHeaders: http.Header{"Accept": {"application/jsoN;Q=0.1,text/html;q=1.1,application/xml;q=-1,*/*;q=0.8"}}, + wantFormat: jsonFormat, + }, + + "x-format json, depends on weight": { + giveHeaders: http.Header{"X-Format": {"application/jsoN,*/*;q=0.8"}}, + wantFormat: jsonFormat, + }, + "x-format xml": { + giveHeaders: http.Header{"X-Format": {"application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}}, + wantFormat: xmlFormat, + }, + + "content type has priority over accept": { + giveHeaders: http.Header{"Content-Type": {"text/plain"}, "Accept": {"application/xml"}}, + wantFormat: plainTextFormat, + }, + "accept has priority over x-format": { + giveHeaders: http.Header{"Accept": {"application/xml"}, "X-Format": {"text/plain"}}, + wantFormat: plainTextFormat, + }, + + "empty headers": { + giveHeaders: nil, + }, + "empty content type": { + giveHeaders: http.Header{"Content-Type": {" "}}, + }, + "wrong content type": { + giveHeaders: http.Header{"Content-Type": {"multipart/form-data; boundary=something"}}, + }, + "wrong accept": { + giveHeaders: http.Header{"Accept": {";q=foobar,bar/baz;;;;;application/xml"}}, + }, + "none on invalid input": { + giveHeaders: http.Header{"Content-Type": {"foo/bar; charset=utf-8"}, "Accept": {"foo/bar; charset=utf-8"}}, + }, + "completely unknown": { + giveHeaders: http.Header{"Content-Type": {"😀"}, "Accept": {"😄"}, "X-Format": {"😍"}}, + }, + } { + tt := _tt + + t.Run(name, func(t *testing.T) { + assert.Equal(t, tt.wantFormat, detectPreferredFormatForClient(tt.giveHeaders)) + }) + } +} diff --git a/internal/http/handlers/error_page/handler.go b/internal/http/handlers/error_page/handler.go index a1c354f..0e9a513 100644 --- a/internal/http/handlers/error_page/handler.go +++ b/internal/http/handlers/error_page/handler.go @@ -3,9 +3,6 @@ package error_page import ( "fmt" "net/http" - "path/filepath" - "strconv" - "strings" "gh.tarampamp.am/error-pages/internal/config" ) @@ -14,9 +11,9 @@ func New(cfg *config.Config) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var code uint16 - if fromUrl, okUrl := ExtractCodeFromURL(r.URL.Path); okUrl { + if fromUrl, okUrl := extractCodeFromURL(r.URL.Path); okUrl { code = fromUrl - } else if fromHeader, okHeaders := ExtractCodeFromHeaders(r.Header); okHeaders { + } else if fromHeader, okHeaders := extractCodeFromHeaders(r.Header); okHeaders { code = fromHeader } else { code = cfg.DefaultCodeToRender @@ -30,55 +27,38 @@ func New(cfg *config.Config) http.Handler { httpCode = http.StatusOK } - w.Header().Set("Content-Type", "text/html; charset=utf-8") // TODO: should depends on requested type + var format = detectPreferredFormatForClient(r.Header) + + switch headerName := "Content-Type"; format { + case jsonFormat: + w.Header().Set(headerName, "application/json; charset=utf-8") + case xmlFormat: + w.Header().Set(headerName, "application/xml; charset=utf-8") + case htmlFormat: + w.Header().Set(headerName, "text/html; charset=utf-8") + case plainTextFormat: + w.Header().Set(headerName, "text/plain; charset=utf-8") + default: + w.Header().Set(headerName, "text/html; charset=utf-8") + } + + // https://developers.google.com/search/docs/crawling-indexing/robots-meta-tag + // disallow indexing of the error pages + w.Header().Set("X-Robots-Tag", "noindex") + + if code >= 500 && code < 600 { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + // tell the client (search crawler) to retry the request after 120 seconds, it makes sense for the 5xx errors + w.Header().Set("Retry-After", "120") + } + + for _, proxyHeader := range cfg.ProxyHeaders { + if value := r.Header.Get(proxyHeader); value != "" { + w.Header().Set(proxyHeader, value) + } + } + w.WriteHeader(httpCode) _, _ = w.Write([]byte(fmt.Sprintf("error page for the code %d", code))) }) } - -// ExtractCodeFromURL extracts the error code from the given URL. -func ExtractCodeFromURL(url string) (uint16, bool) { - var parts = strings.SplitN(strings.TrimLeft(url, "/"), "/", 1) - - if len(parts) == 0 { - return 0, false - } - - var ( - fileName = parts[0] - ext = strings.ToLower(filepath.Ext(fileName)) // ".html", ".htm", ".%something%" or an empty string - ) - - if ext != "" && ext != ".html" && ext != ".htm" { - return 0, false - } else if ext != "" { - fileName = strings.TrimSuffix(fileName, ext) - } - - if code, err := strconv.ParseUint(fileName, 10, 16); err == nil && code > 0 && code < 999 { - return uint16(code), true - } - - return 0, false -} - -// URLContainsCode checks if the given URL contains an error code. -func URLContainsCode(url string) (ok bool) { _, ok = ExtractCodeFromURL(url); return } //nolint:nlreturn - -// ExtractCodeFromHeaders extracts the error code from the given headers. -func ExtractCodeFromHeaders(headers http.Header) (uint16, bool) { - if value := headers.Get("X-Code"); len(value) > 0 && len(value) <= 3 { - if code, err := strconv.ParseUint(value, 10, 16); err == nil && code > 0 && code < 999 { - return uint16(code), true - } - } - - return 0, false -} - -// HeadersContainCode checks if the given headers contain an error code. -func HeadersContainCode(headers http.Header) (ok bool) { - _, ok = ExtractCodeFromHeaders(headers) - - return -}