From 9ac0e8c943d0d272527d81d98a88a97e8c067dd5 Mon Sep 17 00:00:00 2001 From: Jamie Curnow Date: Wed, 2 Aug 2023 15:32:16 +1000 Subject: [PATCH] Add more unit tests --- .../api/middleware/access_control_test.go | 25 ++++++ .../api/middleware/body_context_test.go | 39 ++++++++ backend/internal/api/middleware/cors_test.go | 79 ++++++++++++++++ .../api/middleware/enforce_setup_test.go | 61 +++++++++++++ .../internal/api/middleware/expansion_test.go | 69 ++++++++++++++ .../api/middleware/list_query_test.go | 89 +++++++++++++++++++ .../internal/api/middleware/pretty_print.go | 3 +- 7 files changed, 363 insertions(+), 2 deletions(-) create mode 100644 backend/internal/api/middleware/access_control_test.go create mode 100644 backend/internal/api/middleware/body_context_test.go create mode 100644 backend/internal/api/middleware/cors_test.go create mode 100644 backend/internal/api/middleware/enforce_setup_test.go create mode 100644 backend/internal/api/middleware/expansion_test.go create mode 100644 backend/internal/api/middleware/list_query_test.go diff --git a/backend/internal/api/middleware/access_control_test.go b/backend/internal/api/middleware/access_control_test.go new file mode 100644 index 00000000..001736c0 --- /dev/null +++ b/backend/internal/api/middleware/access_control_test.go @@ -0,0 +1,25 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "npm/internal/api/middleware" + + "github.com/stretchr/testify/assert" +) + +func TestAccessControl(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + assert.Nil(t, err) + accessControl := middleware.AccessControl(handler) + accessControl.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) +} diff --git a/backend/internal/api/middleware/body_context_test.go b/backend/internal/api/middleware/body_context_test.go new file mode 100644 index 00000000..97ae734b --- /dev/null +++ b/backend/internal/api/middleware/body_context_test.go @@ -0,0 +1,39 @@ +package middleware_test + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + c "npm/internal/api/context" + "npm/internal/api/middleware" + + "github.com/stretchr/testify/assert" +) + +func TestBodyContext(t *testing.T) { + // Create a test request with a body + body := []byte(`{"name": "John", "age": 30}`) + req, err := http.NewRequest("POST", "/test", bytes.NewBuffer(body)) + assert.Nil(t, err) + + // Create a test response recorder + rr := httptest.NewRecorder() + + // Create a test handler that checks the context for the body data + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyData := r.Context().Value(c.BodyCtxKey).([]byte) + assert.Equal(t, body, bodyData) + }) + + // Wrap the handler with the BodyContext middleware + mw := middleware.BodyContext()(handler) + + // Call the middleware with the test request and response recorder + mw.ServeHTTP(rr, req) + + // Check that the response status code is 200 + status := rr.Code + assert.Equal(t, http.StatusOK, status) +} diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go new file mode 100644 index 00000000..ffa1802c --- /dev/null +++ b/backend/internal/api/middleware/cors_test.go @@ -0,0 +1,79 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "npm/internal/api/middleware" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" +) + +func TestCors(t *testing.T) { + r := chi.NewRouter() + r.Use(middleware.Cors(r)) + + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test")) + }) + + req, err := http.NewRequest("GET", "/test", nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, "GET,OPTIONS", rr.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "Authorization,Host,Content-Type,Connection,User-Agent,Cache-Control,Accept-Encoding", rr.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "test", rr.Body.String()) +} + +func TestCorsNoRoute(t *testing.T) { + r := chi.NewRouter() + r.Use(middleware.Cors(r)) + + req, err := http.NewRequest("GET", "/test", nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers")) +} + +func TestOptions(t *testing.T) { + r := chi.NewRouter() + r.Use(middleware.Options(r)) + + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test")) + }) + + req, err := http.NewRequest("OPTIONS", "/test", nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) + assert.Equal(t, "{}", rr.Body.String()) +} + +func TestOptionsNoRoute(t *testing.T) { + r := chi.NewRouter() + r.Use(middleware.Options(r)) + + req, err := http.NewRequest("OPTIONS", "/test", nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Headers")) +} diff --git a/backend/internal/api/middleware/enforce_setup_test.go b/backend/internal/api/middleware/enforce_setup_test.go new file mode 100644 index 00000000..ce542bd9 --- /dev/null +++ b/backend/internal/api/middleware/enforce_setup_test.go @@ -0,0 +1,61 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "npm/internal/api/middleware" + "npm/internal/config" +) + +func TestEnforceSetup(t *testing.T) { + tests := []struct { + name string + shouldBeSetup bool + isSetup bool + expectedCode int + }{ + { + name: "should allow request when setup is expected and is setup", + shouldBeSetup: true, + isSetup: true, + expectedCode: http.StatusOK, + }, + { + name: "should error when setup is expected but not setup", + shouldBeSetup: true, + isSetup: false, + expectedCode: http.StatusForbidden, + }, + { + name: "should allow request when setup is not expected and not setup", + shouldBeSetup: false, + isSetup: false, + expectedCode: http.StatusOK, + }, + { + name: "should error when setup is not expected but is setup", + shouldBeSetup: false, + isSetup: true, + expectedCode: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.IsSetup = tt.isSetup + + handler := middleware.EnforceSetup(tt.shouldBeSetup)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, tt.expectedCode, w.Code) + }) + } +} diff --git a/backend/internal/api/middleware/expansion_test.go b/backend/internal/api/middleware/expansion_test.go new file mode 100644 index 00000000..da38d378 --- /dev/null +++ b/backend/internal/api/middleware/expansion_test.go @@ -0,0 +1,69 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + c "npm/internal/api/context" + "npm/internal/api/middleware" + + "github.com/stretchr/testify/assert" +) + +func TestExpansion(t *testing.T) { + t.Run("with expand query param", func(t *testing.T) { + req, err := http.NewRequest("GET", "/path?expand=item1,item2", nil) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expand := middleware.GetExpandFromContext(r) + assert.Equal(t, []string{"item1", "item2"}, expand) + }) + + middleware.Expansion(handler).ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("without expand query param", func(t *testing.T) { + req, err := http.NewRequest("GET", "/path", nil) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expand := middleware.GetExpandFromContext(r) + assert.Nil(t, expand) + }) + + middleware.Expansion(handler).ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) +} + +func TestGetExpandFromContext(t *testing.T) { + t.Run("with context value", func(t *testing.T) { + req, err := http.NewRequest("GET", "/path", nil) + assert.NoError(t, err) + + ctx := req.Context() + ctx = context.WithValue(ctx, c.ExpansionCtxKey, []string{"item1", "item2"}) + req = req.WithContext(ctx) + + expand := middleware.GetExpandFromContext(req) + assert.Equal(t, []string{"item1", "item2"}, expand) + }) + + t.Run("without context value", func(t *testing.T) { + req, err := http.NewRequest("GET", "/path", nil) + assert.NoError(t, err) + + expand := middleware.GetExpandFromContext(req) + assert.Nil(t, expand) + }) +} diff --git a/backend/internal/api/middleware/list_query_test.go b/backend/internal/api/middleware/list_query_test.go new file mode 100644 index 00000000..53973fad --- /dev/null +++ b/backend/internal/api/middleware/list_query_test.go @@ -0,0 +1,89 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + c "npm/internal/api/context" + "npm/internal/api/middleware" + "npm/internal/entity/user" + "npm/internal/model" + "npm/internal/tags" + + "github.com/stretchr/testify/assert" +) + +func TestListQuery(t *testing.T) { + tests := []struct { + name string + queryParams string + expectedStatus int + }{ + { + name: "valid query params", + queryParams: "?name:contains=John&sort=name.desc", + expectedStatus: http.StatusOK, + }, + { + name: "invalid sort field", + queryParams: "?name:contains=John&sort=invalid_field", + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid filter value", + queryParams: "?name=123", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "/test"+tt.queryParams, nil) + assert.NoError(t, err) + + testObj := user.Model{} + + ctx := context.Background() + ctx = context.WithValue(ctx, c.FiltersCtxKey, tags.GetFilterSchema(testObj)) + + rr := httptest.NewRecorder() + handler := middleware.ListQuery(testObj)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + handler.ServeHTTP(rr, req.WithContext(ctx)) + + assert.Equal(t, tt.expectedStatus, rr.Code) + }) + } +} + +func TestGetFiltersFromContext(t *testing.T) { + req, err := http.NewRequest("GET", "/test", nil) + assert.NoError(t, err) + + filters := []model.Filter{ + {Field: "name", Modifier: "contains", Value: []string{"test"}}, + } + ctx := context.WithValue(req.Context(), c.FiltersCtxKey, filters) + req = req.WithContext(ctx) + + result := middleware.GetFiltersFromContext(req) + assert.Equal(t, filters, result) +} + +func TestGetSortFromContext(t *testing.T) { + req, err := http.NewRequest("GET", "/test", nil) + assert.NoError(t, err) + + sorts := []model.Sort{ + {Field: "name", Direction: "asc"}, + } + ctx := context.WithValue(req.Context(), c.SortCtxKey, sorts) + req = req.WithContext(ctx) + + result := middleware.GetSortFromContext(req) + assert.Equal(t, sorts, result) +} diff --git a/backend/internal/api/middleware/pretty_print.go b/backend/internal/api/middleware/pretty_print.go index 270d2a24..7baa14a2 100644 --- a/backend/internal/api/middleware/pretty_print.go +++ b/backend/internal/api/middleware/pretty_print.go @@ -13,8 +13,7 @@ func PrettyPrint(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { prettyStr := r.URL.Query().Get("pretty") if prettyStr == "1" || prettyStr == "true" { - ctx := r.Context() - ctx = context.WithValue(ctx, c.PrettyPrintCtxKey, true) + ctx := context.WithValue(r.Context(), c.PrettyPrintCtxKey, true) next.ServeHTTP(w, r.WithContext(ctx)) } else { next.ServeHTTP(w, r)