mirror of
https://github.com/tarampampam/error-pages.git
synced 2024-08-30 18:22:40 +00:00
222 lines
5.7 KiB
Go
222 lines
5.7 KiB
Go
package http_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
|
|
"gh.tarampamp.am/error-pages/internal/config"
|
|
appHttp "gh.tarampamp.am/error-pages/internal/http"
|
|
)
|
|
|
|
func TestRouting(t *testing.T) {
|
|
var (
|
|
srv = appHttp.NewServer(context.Background(), zap.NewNop())
|
|
cfg = config.New()
|
|
)
|
|
|
|
require.NoError(t, srv.Register(&cfg))
|
|
|
|
var baseUrl, stopServer = startServer(t, &srv)
|
|
|
|
defer stopServer()
|
|
|
|
t.Run("health", func(t *testing.T) {
|
|
var routes = []string{"/health/live", "/health", "/healthz", "/live"}
|
|
|
|
t.Run("success (get)", func(t *testing.T) {
|
|
for _, route := range routes {
|
|
status, body, headers := sendRequest(t, http.MethodGet, baseUrl+route)
|
|
|
|
assert.Equal(t, http.StatusOK, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
|
}
|
|
})
|
|
|
|
t.Run("success (head)", func(t *testing.T) {
|
|
for _, route := range routes {
|
|
status, body, headers := sendRequest(t, http.MethodHead, baseUrl+route)
|
|
|
|
assert.Equal(t, http.StatusOK, status)
|
|
assert.Empty(t, body)
|
|
assert.Empty(t, headers.Get("Content-Type"))
|
|
}
|
|
})
|
|
|
|
t.Run("method not allowed", func(t *testing.T) {
|
|
for _, route := range routes {
|
|
var url = baseUrl + route
|
|
|
|
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
|
|
status, body, headers := sendRequest(t, method, url)
|
|
|
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
|
}
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("version", func(t *testing.T) {
|
|
var url = baseUrl + "/version"
|
|
|
|
t.Run("success (get)", func(t *testing.T) {
|
|
status, body, headers := sendRequest(t, http.MethodGet, url)
|
|
|
|
assert.Equal(t, http.StatusOK, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "application/json")
|
|
})
|
|
|
|
t.Run("success (head)", func(t *testing.T) {
|
|
status, body, headers := sendRequest(t, http.MethodHead, url)
|
|
|
|
assert.Equal(t, http.StatusOK, status)
|
|
assert.Empty(t, body)
|
|
assert.Empty(t, headers.Get("Content-Type"))
|
|
})
|
|
|
|
t.Run("method not allowed", func(t *testing.T) {
|
|
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
|
|
status, body, headers := sendRequest(t, method, url)
|
|
|
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("error page", func(t *testing.T) {
|
|
t.Skip("not implemented")
|
|
})
|
|
|
|
t.Run("errors handling", func(t *testing.T) {
|
|
var missingRoutes = []string{"/not-found", "/not-found/", "/not-found.html"}
|
|
|
|
t.Run("not found (get)", func(t *testing.T) {
|
|
for _, path := range missingRoutes {
|
|
status, body, headers := sendRequest(t, http.MethodGet, baseUrl+path)
|
|
|
|
assert.Equal(t, http.StatusNotFound, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
|
}
|
|
})
|
|
|
|
t.Run("not found (head)", func(t *testing.T) {
|
|
for _, path := range missingRoutes {
|
|
status, body, headers := sendRequest(t, http.MethodHead, baseUrl+path)
|
|
|
|
assert.Equal(t, http.StatusNotFound, status)
|
|
assert.Empty(t, body)
|
|
assert.Empty(t, headers.Get("Content-Type"))
|
|
}
|
|
})
|
|
|
|
t.Run("methods not allowed", func(t *testing.T) {
|
|
for _, path := range missingRoutes {
|
|
for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut} {
|
|
status, body, headers := sendRequest(t, method, baseUrl+path)
|
|
|
|
assert.Equal(t, http.StatusMethodNotAllowed, status)
|
|
assert.NotEmpty(t, body)
|
|
assert.Contains(t, headers.Get("Content-Type"), "text/plain")
|
|
}
|
|
}
|
|
})
|
|
})
|
|
}
|
|
|
|
// sendRequest is a helper function to send an HTTP request and return its status code, body, and headers.
|
|
func sendRequest(t *testing.T, method, url string, headers ...map[string]string) (
|
|
status int,
|
|
body []byte,
|
|
_ http.Header,
|
|
) {
|
|
t.Helper()
|
|
|
|
req, reqErr := http.NewRequest(method, url, nil)
|
|
|
|
require.NoError(t, reqErr)
|
|
|
|
if len(headers) > 0 {
|
|
for key, value := range headers[0] {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
|
|
body, _ = io.ReadAll(resp.Body)
|
|
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
return resp.StatusCode, body, resp.Header
|
|
}
|
|
|
|
// startServer is a helper function to start an HTTP server and return its base URL and a stop function.
|
|
func startServer(t *testing.T, srv *appHttp.Server) (_ string, stop func()) {
|
|
t.Helper()
|
|
|
|
var (
|
|
port = getFreeTcpPort(t)
|
|
hostPort = fmt.Sprintf("%s:%d", "127.0.0.1", port)
|
|
)
|
|
|
|
go func() {
|
|
if err := srv.Start("127.0.0.1", port); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
// wait until the server starts
|
|
for {
|
|
if conn, err := net.DialTimeout("tcp", hostPort, time.Second); err == nil {
|
|
require.NoError(t, conn.Close())
|
|
|
|
break
|
|
}
|
|
|
|
<-time.After(5 * time.Millisecond)
|
|
}
|
|
|
|
return fmt.Sprintf("http://%s", hostPort), func() { assert.NoError(t, srv.Stop(10*time.Millisecond)) }
|
|
}
|
|
|
|
// getFreeTcpPort is a helper function to get a free TCP port number.
|
|
func getFreeTcpPort(t *testing.T) uint16 {
|
|
t.Helper()
|
|
|
|
l, lErr := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, lErr)
|
|
|
|
port := l.Addr().(*net.TCPAddr).Port
|
|
require.NoError(t, l.Close())
|
|
|
|
// make sure port is closed
|
|
for {
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
require.NoError(t, conn.Close())
|
|
<-time.After(5 * time.Millisecond)
|
|
}
|
|
|
|
return uint16(port)
|
|
}
|