Add more unit tests

This commit is contained in:
Jamie Curnow 2023-08-02 15:32:16 +10:00
parent 155e09407f
commit 9ac0e8c943
No known key found for this signature in database
GPG Key ID: FFBB624C43388E9E
7 changed files with 363 additions and 2 deletions

View File

@ -0,0 +1,25 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"npm/internal/api/middleware"
"github.com/stretchr/testify/assert"
)
func TestAccessControl(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
rr := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
assert.Nil(t, err)
accessControl := middleware.AccessControl(handler)
accessControl.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
}

View File

@ -0,0 +1,39 @@
package middleware_test
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
c "npm/internal/api/context"
"npm/internal/api/middleware"
"github.com/stretchr/testify/assert"
)
func TestBodyContext(t *testing.T) {
// Create a test request with a body
body := []byte(`{"name": "John", "age": 30}`)
req, err := http.NewRequest("POST", "/test", bytes.NewBuffer(body))
assert.Nil(t, err)
// Create a test response recorder
rr := httptest.NewRecorder()
// Create a test handler that checks the context for the body data
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyData := r.Context().Value(c.BodyCtxKey).([]byte)
assert.Equal(t, body, bodyData)
})
// Wrap the handler with the BodyContext middleware
mw := middleware.BodyContext()(handler)
// Call the middleware with the test request and response recorder
mw.ServeHTTP(rr, req)
// Check that the response status code is 200
status := rr.Code
assert.Equal(t, http.StatusOK, status)
}

View File

@ -0,0 +1,79 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"npm/internal/api/middleware"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)
func TestCors(t *testing.T) {
r := chi.NewRouter()
r.Use(middleware.Cors(r))
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
req, err := http.NewRequest("GET", "/test", nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)
assert.Equal(t, "GET,OPTIONS", rr.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "Authorization,Host,Content-Type,Connection,User-Agent,Cache-Control,Accept-Encoding", rr.Header().Get("Access-Control-Allow-Headers"))
assert.Equal(t, "test", rr.Body.String())
}
func TestCorsNoRoute(t *testing.T) {
r := chi.NewRouter()
r.Use(middleware.Cors(r))
req, err := http.NewRequest("GET", "/test", nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers"))
}
func TestOptions(t *testing.T) {
r := chi.NewRouter()
r.Use(middleware.Options(r))
r.Get("/test", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
req, err := http.NewRequest("OPTIONS", "/test", nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "application/json", rr.Header().Get("Content-Type"))
assert.Equal(t, "{}", rr.Body.String())
}
func TestOptionsNoRoute(t *testing.T) {
r := chi.NewRouter()
r.Use(middleware.Options(r))
req, err := http.NewRequest("OPTIONS", "/test", nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers"))
}

View File

@ -0,0 +1,61 @@
package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"npm/internal/api/middleware"
"npm/internal/config"
)
func TestEnforceSetup(t *testing.T) {
tests := []struct {
name string
shouldBeSetup bool
isSetup bool
expectedCode int
}{
{
name: "should allow request when setup is expected and is setup",
shouldBeSetup: true,
isSetup: true,
expectedCode: http.StatusOK,
},
{
name: "should error when setup is expected but not setup",
shouldBeSetup: true,
isSetup: false,
expectedCode: http.StatusForbidden,
},
{
name: "should allow request when setup is not expected and not setup",
shouldBeSetup: false,
isSetup: false,
expectedCode: http.StatusOK,
},
{
name: "should error when setup is not expected but is setup",
shouldBeSetup: false,
isSetup: true,
expectedCode: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config.IsSetup = tt.isSetup
handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tt.expectedCode, w.Code)
})
}
}

View File

@ -0,0 +1,69 @@
package middleware_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
c "npm/internal/api/context"
"npm/internal/api/middleware"
"github.com/stretchr/testify/assert"
)
func TestExpansion(t *testing.T) {
t.Run("with expand query param", func(t *testing.T) {
req, err := http.NewRequest("GET", "/path?expand=item1,item2", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expand := middleware.GetExpandFromContext(r)
assert.Equal(t, []string{"item1", "item2"}, expand)
})
middleware.Expansion(handler).ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
t.Run("without expand query param", func(t *testing.T) {
req, err := http.NewRequest("GET", "/path", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expand := middleware.GetExpandFromContext(r)
assert.Nil(t, expand)
})
middleware.Expansion(handler).ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
})
}
func TestGetExpandFromContext(t *testing.T) {
t.Run("with context value", func(t *testing.T) {
req, err := http.NewRequest("GET", "/path", nil)
assert.NoError(t, err)
ctx := req.Context()
ctx = context.WithValue(ctx, c.ExpansionCtxKey, []string{"item1", "item2"})
req = req.WithContext(ctx)
expand := middleware.GetExpandFromContext(req)
assert.Equal(t, []string{"item1", "item2"}, expand)
})
t.Run("without context value", func(t *testing.T) {
req, err := http.NewRequest("GET", "/path", nil)
assert.NoError(t, err)
expand := middleware.GetExpandFromContext(req)
assert.Nil(t, expand)
})
}

View File

@ -0,0 +1,89 @@
package middleware_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
c "npm/internal/api/context"
"npm/internal/api/middleware"
"npm/internal/entity/user"
"npm/internal/model"
"npm/internal/tags"
"github.com/stretchr/testify/assert"
)
func TestListQuery(t *testing.T) {
tests := []struct {
name string
queryParams string
expectedStatus int
}{
{
name: "valid query params",
queryParams: "?name:contains=John&sort=name.desc",
expectedStatus: http.StatusOK,
},
{
name: "invalid sort field",
queryParams: "?name:contains=John&sort=invalid_field",
expectedStatus: http.StatusBadRequest,
},
{
name: "invalid filter value",
queryParams: "?name=123",
expectedStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest("GET", "/test"+tt.queryParams, nil)
assert.NoError(t, err)
testObj := user.Model{}
ctx := context.Background()
ctx = context.WithValue(ctx, c.FiltersCtxKey, tags.GetFilterSchema(testObj))
rr := httptest.NewRecorder()
handler := middleware.ListQuery(testObj)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
handler.ServeHTTP(rr, req.WithContext(ctx))
assert.Equal(t, tt.expectedStatus, rr.Code)
})
}
}
func TestGetFiltersFromContext(t *testing.T) {
req, err := http.NewRequest("GET", "/test", nil)
assert.NoError(t, err)
filters := []model.Filter{
{Field: "name", Modifier: "contains", Value: []string{"test"}},
}
ctx := context.WithValue(req.Context(), c.FiltersCtxKey, filters)
req = req.WithContext(ctx)
result := middleware.GetFiltersFromContext(req)
assert.Equal(t, filters, result)
}
func TestGetSortFromContext(t *testing.T) {
req, err := http.NewRequest("GET", "/test", nil)
assert.NoError(t, err)
sorts := []model.Sort{
{Field: "name", Direction: "asc"},
}
ctx := context.WithValue(req.Context(), c.SortCtxKey, sorts)
req = req.WithContext(ctx)
result := middleware.GetSortFromContext(req)
assert.Equal(t, sorts, result)
}

View File

@ -13,8 +13,7 @@ func PrettyPrint(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
prettyStr := r.URL.Query().Get("pretty") prettyStr := r.URL.Query().Get("pretty")
if prettyStr == "1" || prettyStr == "true" { if prettyStr == "1" || prettyStr == "true" {
ctx := r.Context() ctx := context.WithValue(r.Context(), c.PrettyPrintCtxKey, true)
ctx = context.WithValue(ctx, c.PrettyPrintCtxKey, true)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} else { } else {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)