Add backend unit tests

This commit is contained in:
Jamie Curnow
2023-07-25 11:59:02 +10:00
parent 72b071dbaa
commit b123ca4fd0
17 changed files with 399 additions and 283 deletions

View File

@ -26,7 +26,7 @@ func GetCertificates() func(http.ResponseWriter, *http.Request) {
return
}
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -41,7 +41,7 @@ func GetCertificate() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if item := getCertificateFromRequest(w, r); item != nil {
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
}
}

View File

@ -4,8 +4,6 @@ import (
"net/http"
"strconv"
"npm/internal/api/context"
"npm/internal/api/middleware"
"npm/internal/model"
"github.com/go-chi/chi/v5"
@ -15,7 +13,7 @@ import (
const defaultLimit = 10
func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
var pageInfo model.PageInfo
pageInfo := model.PageInfo{}
var err error
pageInfo.Offset, pageInfo.Limit, err = getPagination(r)
@ -23,7 +21,7 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
return pageInfo, err
}
pageInfo.Sort = middleware.GetSortFromContext(r)
// pageInfo.Sort = middleware.GetSortFromContext(r)
return pageInfo, nil
}
@ -93,12 +91,3 @@ func getPagination(r *http.Request) (int, int, error) {
return offset, limit, nil
}
// getExpandFromContext returns the Expansion setting
func getExpandFromContext(r *http.Request) []string {
expand, ok := r.Context().Value(context.ExpansionCtxKey).([]string)
if !ok {
return nil
}
return expand
}

View File

@ -0,0 +1,118 @@
package handler
import (
"net/http"
"net/http/httptest"
"npm/internal/model"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetPageInfoFromRequest(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/hosts", nil)
p, err := getPageInfoFromRequest(r)
var nilStringSlice []string
var nilSortSlice []model.Sort
defaultSort := model.Sort{Field: "name", Direction: "asc"}
assert.Equal(t, nil, err)
assert.Equal(t, 0, p.Offset)
assert.Equal(t, 10, p.Limit)
assert.Equal(t, nilStringSlice, p.Expand)
assert.Equal(t, nilSortSlice, p.Sort)
assert.Equal(t, []model.Sort{defaultSort}, p.GetSort(defaultSort))
})
}
func TestGetQueryVarInt(t *testing.T) {
type want struct {
val int
err string
}
tests := []struct {
name string
url string
queryVar string
required bool
defaultValue int
want want
}{
{
name: "simple default",
url: "/hosts",
queryVar: "something",
required: false,
defaultValue: 100,
want: want{
val: 100,
err: "",
},
},
{
name: "required flag",
url: "/hosts",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "something was not supplied in the request",
},
},
{
name: "simple get",
url: "/hosts?something=50",
queryVar: "something",
required: true,
want: want{
val: 50,
err: "",
},
},
{
name: "invalid number",
url: "/hosts?something=aaa",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "",
},
},
{
name: "preceding zeros",
url: "/hosts?something=0000050",
queryVar: "something",
required: true,
want: want{
val: 50,
err: "",
},
},
{
name: "decimals",
url: "/hosts?something=50.50",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
val, err := getQueryVarInt(r, tt.queryVar, tt.required, tt.defaultValue)
assert.Equal(t, tt.want.val, val)
if tt.want.err != "" {
assert.NotEqual(t, nil, err)
assert.Equal(t, tt.want.err, err.Error())
}
})
}
}

View File

@ -26,7 +26,7 @@ func GetHosts() func(http.ResponseWriter, *http.Request) {
return
}
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -52,7 +52,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -132,7 +132,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
hostObject.Expand(getExpandFromContext(r))
hostObject.Expand(middleware.GetExpandFromContext(r))
configureHost(hostObject)

View File

@ -25,14 +25,21 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
assetsSub, _ = fs.Sub(embed.Assets, "assets")
return func(w http.ResponseWriter, r *http.Request) {
defaultFile := "index.html"
path := strings.TrimLeft(r.URL.Path, "/")
isAPI := false
if len(path) >= 3 && path[0:3] == "api" {
isAPI = true
}
if path == "" {
path = "index.html"
path = defaultFile
}
err := tryRead(assetsSub, path, w)
if err == errIsDir {
err = tryRead(assetsSub, "index.html", w)
err = tryRead(assetsSub, defaultFile, w)
if err != nil {
h.NotFound(w, r)
}
@ -40,6 +47,16 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
return
}
// Check if the path has an extension and not in the "/api" path
ext := filepath.Ext(path)
if !isAPI && ext == "" {
// Not an api endpoint and Not a specific file, return the default index file
err := tryRead(assetsSub, defaultFile, w)
if err == nil {
return
}
}
h.NotFound(w, r)
}
}

View File

@ -27,7 +27,7 @@ func GetUpstreams() func(http.ResponseWriter, *http.Request) {
return
}
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -53,7 +53,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -127,7 +127,7 @@ func UpdateUpstream() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
// item.Expand(getExpandFromContext(r))
// item.Expand(middleware.GetExpandFromContext(r))
configureUpstream(item)

View File

@ -27,7 +27,7 @@ func GetUsers() func(http.ResponseWriter, *http.Request) {
return
}
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -52,7 +52,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -108,7 +108,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
userObject.Expand(getExpandFromContext(r))
userObject.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, userObject)
default: