error-pages/internal/http/server_test.go

397 lines
12 KiB
Go
Raw Normal View History

2024-06-23 12:12:06 +00:00
package http_test
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gh.tarampamp.am/error-pages/internal/config"
appHttp "gh.tarampamp.am/error-pages/internal/http"
2024-06-24 15:28:03 +00:00
"gh.tarampamp.am/error-pages/internal/logger"
2024-06-23 12:12:06 +00:00
)
2024-06-25 18:26:34 +00:00
// TestRouting in fact is a test for the whole server, because it tests all the routes and their handlers.
2024-06-23 12:12:06 +00:00
func TestRouting(t *testing.T) {
var (
2024-06-24 15:28:03 +00:00
srv = appHttp.NewServer(context.Background(), logger.NewNop())
2024-06-23 12:12:06 +00:00
cfg = config.New()
)
2024-06-25 18:26:34 +00:00
assert.NoError(t, cfg.Templates.Add("unit-test", `<!DOCTYPE html>
<html lang="en">
<h1>Error {{ code }}: {{ message }}</h1>{{ if description }}
<h2>{{ description }}</h2>{{ end }}{{ if show_details }}
<pre>
Host: {{ host }}
Original URI: {{ original_uri }}
Forwarded For: {{ forwarded_for }}
Namespace: {{ namespace }}
Ingress Name: {{ ingress_name }}
Service Name: {{ service_name }}
Service Port: {{ service_port }}
Request ID: {{ request_id }}
Timestamp: {{ now.Unix }}
</pre>{{ end }}
</html>`))
cfg.TemplateName = "unit-test"
2024-06-23 12:12:06 +00:00
require.NoError(t, srv.Register(&cfg))
var baseUrl, stopServer = startServer(t, &srv)
defer stopServer()
t.Run("health", func(t *testing.T) {
var routes = []string{"/health/live", "/health", "/healthz", "/live"}
t.Run("success (get)", func(t *testing.T) {
for _, route := range routes {
status, body, headers := sendRequest(t, http.MethodGet, baseUrl+route)
assert.Equal(t, http.StatusOK, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
})
t.Run("success (head)", func(t *testing.T) {
for _, route := range routes {
status, body, headers := sendRequest(t, http.MethodHead, baseUrl+route)
assert.Equal(t, http.StatusOK, status)
assert.Empty(t, body)
assert.Empty(t, headers.Get("Content-Type"))
}
})
t.Run("method not allowed", func(t *testing.T) {
for _, route := range routes {
var url = baseUrl + route
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
status, body, headers := sendRequest(t, method, url)
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
}
})
})
t.Run("version", func(t *testing.T) {
var url = baseUrl + "/version"
t.Run("success (get)", func(t *testing.T) {
status, body, headers := sendRequest(t, http.MethodGet, url)
assert.Equal(t, http.StatusOK, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "application/json")
})
t.Run("success (head)", func(t *testing.T) {
status, body, headers := sendRequest(t, http.MethodHead, url)
assert.Equal(t, http.StatusOK, status)
assert.Empty(t, body)
assert.Empty(t, headers.Get("Content-Type"))
})
t.Run("method not allowed", func(t *testing.T) {
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
status, body, headers := sendRequest(t, method, url)
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
})
})
t.Run("error page", func(t *testing.T) {
2024-06-23 19:48:51 +00:00
t.Run("success", func(t *testing.T) {
2024-06-25 18:26:34 +00:00
t.Run("index, default (plain text by default)", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/")
assert.Equal(t, http.StatusOK, status)
assert.Contains(t, string(body), "404: Not Found")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
})
2024-06-23 19:48:51 +00:00
2024-06-25 18:26:34 +00:00
t.Run("index, default (json format)", func(t *testing.T) {
var status, body, headers = sendRequest(t,
http.MethodGet, baseUrl+"/", map[string]string{"Accept": "application/json"},
)
2024-06-23 19:48:51 +00:00
2024-06-25 18:26:34 +00:00
assert.Equal(t, http.StatusOK, status)
assert.Contains(t, string(body), `"code": 404`)
assert.Contains(t, headers.Get("Content-Type"), "application/json")
})
2024-06-23 19:48:51 +00:00
2024-06-25 18:26:34 +00:00
t.Run("index, default (xml format)", func(t *testing.T) {
var status, body, headers = sendRequest(t,
http.MethodGet, baseUrl+"/", map[string]string{"Accept": "application/xml"},
)
2024-06-23 19:48:51 +00:00
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), `<code>404</code>`)
assert.Contains(t, headers.Get("Content-Type"), "application/xml")
})
t.Run("index, default (html format)", func(t *testing.T) {
var status, body, headers = sendRequest(t,
http.MethodGet, baseUrl+"/", map[string]string{"Content-Type": "text/html"},
)
assert.Equal(t, http.StatusOK, status)
assert.Contains(t, string(body), `<h1>Error 404: Not Found</h1>`)
assert.Contains(t, headers.Get("Content-Type"), "text/html")
2024-06-23 19:48:51 +00:00
})
t.Run("index, code in HTTP header", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/", map[string]string{"X-Code": "404"})
assert.Equal(t, http.StatusOK, status) // because of [cfg.RespondWithSameHTTPCode] is false by default
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "404: Not Found")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("code in URL, .html", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/500.html")
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "500: Internal Server Error")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("code in URL, .htm", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/409.htm")
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "409: Conflict")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("code in URL, without extension", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/405")
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "405: Method Not Allowed")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("code in the URL have higher priority than in the headers", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/405", map[string]string{"X-Code": "404"})
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "405: Method Not Allowed")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("invalid code in HTTP header (with a string)", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/", map[string]string{"X-Code": "foobar"})
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "404: Not Found")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("invalid code in HTTP header (too small)", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/", map[string]string{"X-Code": "0"})
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "404: Not Found")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
t.Run("invalid code in HTTP header (too big)", func(t *testing.T) {
var status, body, headers = sendRequest(t, http.MethodGet, baseUrl+"/", map[string]string{"X-Code": "1000"})
assert.Equal(t, http.StatusOK, status)
2024-06-25 18:26:34 +00:00
assert.Contains(t, string(body), "404: Not Found")
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
2024-06-23 19:48:51 +00:00
})
})
t.Run("failure", func(t *testing.T) {
var assertIsNotErrorPage = func(t *testing.T, body []byte) {
t.Helper()
assert.NotContains(t, string(body), "error page") // FIXME
}
t.Run("invalid code in URL (too small)", func(t *testing.T) {
var status, body, _ = sendRequest(t, http.MethodGet, baseUrl+"/0.html")
assert.Equal(t, http.StatusNotFound, status)
assertIsNotErrorPage(t, body)
})
t.Run("invalid code in URL (too big)", func(t *testing.T) {
var status, body, _ = sendRequest(t, http.MethodGet, baseUrl+"/1000.html")
assert.Equal(t, http.StatusNotFound, status)
assertIsNotErrorPage(t, body)
})
t.Run("invalid code in URL (with a string suffix)", func(t *testing.T) {
var status, body, _ = sendRequest(t, http.MethodGet, baseUrl+"/404foobar.html")
assert.Equal(t, http.StatusNotFound, status)
assertIsNotErrorPage(t, body)
})
t.Run("invalid code in URL (with a string prefix)", func(t *testing.T) {
var status, body, _ = sendRequest(t, http.MethodGet, baseUrl+"/foobar404.html")
assert.Equal(t, http.StatusNotFound, status)
assertIsNotErrorPage(t, body)
})
t.Run("invalid code in URL (with a string)", func(t *testing.T) {
var status, body, _ = sendRequest(t, http.MethodGet, baseUrl+"/foobar.html")
assert.Equal(t, http.StatusNotFound, status)
assertIsNotErrorPage(t, body)
})
t.Run("invalid HTTP methods", func(t *testing.T) {
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
var status, body, _ = sendRequest(t, method, baseUrl+"/404.html")
assert.Equal(t, http.StatusMethodNotAllowed, status)
assertIsNotErrorPage(t, body)
}
})
})
2024-06-23 12:12:06 +00:00
})
t.Run("errors handling", func(t *testing.T) {
var missingRoutes = []string{"/not-found", "/not-found/", "/not-found.html"}
t.Run("not found (get)", func(t *testing.T) {
for _, path := range missingRoutes {
status, body, headers := sendRequest(t, http.MethodGet, baseUrl+path)
assert.Equal(t, http.StatusNotFound, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
})
t.Run("not found (head)", func(t *testing.T) {
for _, path := range missingRoutes {
status, body, headers := sendRequest(t, http.MethodHead, baseUrl+path)
assert.Equal(t, http.StatusNotFound, status)
assert.Empty(t, body)
assert.Empty(t, headers.Get("Content-Type"))
}
})
t.Run("methods not allowed", func(t *testing.T) {
for _, path := range missingRoutes {
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
status, body, headers := sendRequest(t, method, baseUrl+path)
assert.Equal(t, http.StatusMethodNotAllowed, status)
assert.NotEmpty(t, body)
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
}
}
})
})
}
// sendRequest is a helper function to send an HTTP request and return its status code, body, and headers.
func sendRequest(t *testing.T, method, url string, headers ...map[string]string) (
status int,
body []byte,
_ http.Header,
) {
t.Helper()
req, reqErr := http.NewRequest(method, url, nil)
require.NoError(t, reqErr)
if len(headers) > 0 {
for key, value := range headers[0] {
req.Header.Add(key, value)
}
}
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
body, _ = io.ReadAll(resp.Body)
require.NoError(t, resp.Body.Close())
return resp.StatusCode, body, resp.Header
}
// startServer is a helper function to start an HTTP server and return its base URL and a stop function.
func startServer(t *testing.T, srv *appHttp.Server) (_ string, stop func()) {
t.Helper()
var (
port = getFreeTcpPort(t)
hostPort = fmt.Sprintf("%s:%d", "127.0.0.1", port)
)
go func() {
if err := srv.Start("127.0.0.1", port); err != nil && !errors.Is(err, http.ErrServerClosed) {
assert.NoError(t, err)
}
}()
// wait until the server starts
for {
if conn, err := net.DialTimeout("tcp", hostPort, time.Second); err == nil {
require.NoError(t, conn.Close())
break
}
<-time.After(5 * time.Millisecond)
}
return fmt.Sprintf("http://%s", hostPort), func() { assert.NoError(t, srv.Stop(10*time.Millisecond)) }
}
// getFreeTcpPort is a helper function to get a free TCP port number.
func getFreeTcpPort(t *testing.T) uint16 {
t.Helper()
l, lErr := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, lErr)
port := l.Addr().(*net.TCPAddr).Port
require.NoError(t, l.Close())
// make sure port is closed
for {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
break
}
require.NoError(t, conn.Close())
<-time.After(5 * time.Millisecond)
}
return uint16(port)
}