Add backend unit tests

This commit is contained in:
Jamie Curnow 2023-07-25 11:59:02 +10:00
parent 72b071dbaa
commit b123ca4fd0
No known key found for this signature in database
GPG Key ID: FFBB624C43388E9E
17 changed files with 399 additions and 283 deletions

View File

@ -1,4 +1,4 @@
version: "2"
version: "3"
tasks:
default:

View File

@ -136,7 +136,15 @@ func getCommonArgs() []string {
}
// This is split out into it's own function so it's testable
func buildCertRequestArgs(domains []string, method, outputFullchainFile, outputKeyFile string, dnsProvider *dnsprovider.Model, ca *certificateauthority.Model, force bool) ([]string, error) {
func buildCertRequestArgs(
domains []string,
method,
outputFullchainFile,
outputKeyFile string,
dnsProvider *dnsprovider.Model,
ca *certificateauthority.Model,
force bool,
) ([]string, error) {
// The argument order matters.
// see https://github.com/acmesh-official/acme.sh/wiki/How-to-issue-a-cert#3-multiple-domains-san-mode--hybrid-mode
// for multiple domains and note that the method of validation is required just after the domain arg, each time.

View File

@ -193,3 +193,33 @@ func TestBuildCertRequestArgs(t *testing.T) {
})
}
}
func TestGetAcmeShFilePath(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
path, err := getAcmeShFilePath()
assert.Equal(t, "/bin/acme.sh", path)
assert.Equal(t, nil, err)
})
}
func TestGetCommonEnvVars(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
expected := []string{
"ACMESH_CONFIG_HOME=/data/.acme.sh/config",
"ACMESH_HOME=/data/.acme.sh",
"CERT_HOME=/data/.acme.sh/certs",
"LE_CONFIG_HOME=/data/.acme.sh/config",
"LE_WORKING_DIR=/data/.acme.sh",
}
vals := getCommonEnvVars()
assert.Equal(t, expected, vals)
})
}
func TestGetAcmeShVersion(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
resp := GetAcmeShVersion()
assert.Greater(t, len(resp), 1)
assert.Equal(t, "v", resp[:1])
})
}

View File

@ -0,0 +1,13 @@
package context
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetString(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
assert.Equal(t, "context value: Body", BodyCtxKey.String())
})
}

View File

@ -1,208 +0,0 @@
package filters
import (
"fmt"
"strings"
)
// NewFilterSchema is the main method to specify a new Filter Schema for use in Middleware
func NewFilterSchema(fieldSchemas []string) string {
return fmt.Sprintf(baseFilterSchema, strings.Join(fieldSchemas, ", "))
}
// BoolFieldSchema returns the Field Schema for a Boolean accepted value field
func BoolFieldSchema(fieldName string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
%s,
{
"type": "array",
"items": %s
}
]
}
}
}`, fieldName, boolModifiers, filterBool, filterBool)
}
// IntFieldSchema returns the Field Schema for a Integer accepted value field
func IntFieldSchema(fieldName string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
{
"type": "string",
"pattern": "^[0-9]+$"
},
{
"type": "array",
"items": {
"type": "string",
"pattern": "^[0-9]+$"
}
}
]
}
}
}`, fieldName, allModifiers)
}
// StringFieldSchema returns the Field Schema for a String accepted value field
func StringFieldSchema(fieldName string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
%s,
{
"type": "array",
"items": %s
}
]
}
}
}`, fieldName, stringModifiers, filterString, filterString)
}
// RegexFieldSchema returns the Field Schema for a String accepted value field matching a Regex
func RegexFieldSchema(fieldName string, regex string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
{
"type": "string",
"pattern": "%s"
},
{
"type": "array",
"items": {
"type": "string",
"pattern": "%s"
}
}
]
}
}
}`, fieldName, stringModifiers, regex, regex)
}
// DateFieldSchema returns the Field Schema for a String accepted value field matching a Date format
func DateFieldSchema(fieldName string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
{
"type": "string",
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
},
{
"type": "array",
"items": {
"type": "string",
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
}
}
]
}
}
}`, fieldName, allModifiers)
}
// DateTimeFieldSchema returns the Field Schema for a String accepted value field matching a Date format
// 2020-03-01T10:30:00+10:00
func DateTimeFieldSchema(fieldName string) string {
return fmt.Sprintf(`{
"type": "object",
"properties": {
"field": {
"type": "string",
"pattern": "^%s$"
},
"modifier": %s,
"value": {
"oneOf": [
{
"type": "string",
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
},
{
"type": "array",
"items": {
"type": "string",
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
}
}
]
}
}
}`, fieldName, allModifiers)
}
const allModifiers = `{
"type": "string",
"pattern": "^(equals|not|contains|starts|ends|in|notin|min|max|greater|less)$"
}`
const boolModifiers = `{
"type": "string",
"pattern": "^(equals|not)$"
}`
const stringModifiers = `{
"type": "string",
"pattern": "^(equals|not|contains|starts|ends|in|notin)$"
}`
const filterBool = `{
"type": "string",
"pattern": "^(TRUE|true|t|yes|y|on|1|FALSE|f|false|n|no|off|0)$"
}`
const filterString = `{
"type": "string",
"minLength": 1
}`
const baseFilterSchema = `{
"type": "array",
"items": {
"oneOf": [
%s
]
}
}`

View File

@ -26,7 +26,7 @@ func GetCertificates() func(http.ResponseWriter, *http.Request) {
return
}
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -41,7 +41,7 @@ func GetCertificate() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if item := getCertificateFromRequest(w, r); item != nil {
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
}
}

View File

@ -4,8 +4,6 @@ import (
"net/http"
"strconv"
"npm/internal/api/context"
"npm/internal/api/middleware"
"npm/internal/model"
"github.com/go-chi/chi/v5"
@ -15,7 +13,7 @@ import (
const defaultLimit = 10
func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
var pageInfo model.PageInfo
pageInfo := model.PageInfo{}
var err error
pageInfo.Offset, pageInfo.Limit, err = getPagination(r)
@ -23,7 +21,7 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
return pageInfo, err
}
pageInfo.Sort = middleware.GetSortFromContext(r)
// pageInfo.Sort = middleware.GetSortFromContext(r)
return pageInfo, nil
}
@ -93,12 +91,3 @@ func getPagination(r *http.Request) (int, int, error) {
return offset, limit, nil
}
// getExpandFromContext returns the Expansion setting
func getExpandFromContext(r *http.Request) []string {
expand, ok := r.Context().Value(context.ExpansionCtxKey).([]string)
if !ok {
return nil
}
return expand
}

View File

@ -0,0 +1,118 @@
package handler
import (
"net/http"
"net/http/httptest"
"npm/internal/model"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetPageInfoFromRequest(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/hosts", nil)
p, err := getPageInfoFromRequest(r)
var nilStringSlice []string
var nilSortSlice []model.Sort
defaultSort := model.Sort{Field: "name", Direction: "asc"}
assert.Equal(t, nil, err)
assert.Equal(t, 0, p.Offset)
assert.Equal(t, 10, p.Limit)
assert.Equal(t, nilStringSlice, p.Expand)
assert.Equal(t, nilSortSlice, p.Sort)
assert.Equal(t, []model.Sort{defaultSort}, p.GetSort(defaultSort))
})
}
func TestGetQueryVarInt(t *testing.T) {
type want struct {
val int
err string
}
tests := []struct {
name string
url string
queryVar string
required bool
defaultValue int
want want
}{
{
name: "simple default",
url: "/hosts",
queryVar: "something",
required: false,
defaultValue: 100,
want: want{
val: 100,
err: "",
},
},
{
name: "required flag",
url: "/hosts",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "something was not supplied in the request",
},
},
{
name: "simple get",
url: "/hosts?something=50",
queryVar: "something",
required: true,
want: want{
val: 50,
err: "",
},
},
{
name: "invalid number",
url: "/hosts?something=aaa",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "",
},
},
{
name: "preceding zeros",
url: "/hosts?something=0000050",
queryVar: "something",
required: true,
want: want{
val: 50,
err: "",
},
},
{
name: "decimals",
url: "/hosts?something=50.50",
queryVar: "something",
required: true,
want: want{
val: 0,
err: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
val, err := getQueryVarInt(r, tt.queryVar, tt.required, tt.defaultValue)
assert.Equal(t, tt.want.val, val)
if tt.want.err != "" {
assert.NotEqual(t, nil, err)
assert.Equal(t, tt.want.err, err.Error())
}
})
}
}

View File

@ -26,7 +26,7 @@ func GetHosts() func(http.ResponseWriter, *http.Request) {
return
}
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -52,7 +52,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -132,7 +132,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
hostObject.Expand(getExpandFromContext(r))
hostObject.Expand(middleware.GetExpandFromContext(r))
configureHost(hostObject)

View File

@ -25,14 +25,21 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
assetsSub, _ = fs.Sub(embed.Assets, "assets")
return func(w http.ResponseWriter, r *http.Request) {
defaultFile := "index.html"
path := strings.TrimLeft(r.URL.Path, "/")
isAPI := false
if len(path) >= 3 && path[0:3] == "api" {
isAPI = true
}
if path == "" {
path = "index.html"
path = defaultFile
}
err := tryRead(assetsSub, path, w)
if err == errIsDir {
err = tryRead(assetsSub, "index.html", w)
err = tryRead(assetsSub, defaultFile, w)
if err != nil {
h.NotFound(w, r)
}
@ -40,6 +47,16 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
return
}
// Check if the path has an extension and not in the "/api" path
ext := filepath.Ext(path)
if !isAPI && ext == "" {
// Not an api endpoint and Not a specific file, return the default index file
err := tryRead(assetsSub, defaultFile, w)
if err == nil {
return
}
}
h.NotFound(w, r)
}
}

View File

@ -27,7 +27,7 @@ func GetUpstreams() func(http.ResponseWriter, *http.Request) {
return
}
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -53,7 +53,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -127,7 +127,7 @@ func UpdateUpstream() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
// item.Expand(getExpandFromContext(r))
// item.Expand(middleware.GetExpandFromContext(r))
configureUpstream(item)

View File

@ -27,7 +27,7 @@ func GetUsers() func(http.ResponseWriter, *http.Request) {
return
}
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
if err != nil {
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
} else {
@ -52,7 +52,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) {
h.NotFound(w, r)
case nil:
// nolint: errcheck,gosec
item.Expand(getExpandFromContext(r))
item.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, item)
default:
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
@ -108,7 +108,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
}
// nolint: errcheck,gosec
userObject.Expand(getExpandFromContext(r))
userObject.Expand(middleware.GetExpandFromContext(r))
h.ResultResponseJSON(w, r, http.StatusOK, userObject)
default:

View File

@ -1,46 +0,0 @@
package http
import (
"context"
"encoding/json"
"github.com/qri-io/jsonschema"
"github.com/rotisserie/eris"
)
var (
// ErrInvalidJSON is an error for invalid json
ErrInvalidJSON = eris.New("JSON is invalid")
// ErrInvalidPayload is an error for invalid incoming data
ErrInvalidPayload = eris.New("Payload is invalid")
)
// ValidateRequestSchema takes a Schema and the Content to validate against it
func ValidateRequestSchema(schema string, requestBody []byte) ([]jsonschema.KeyError, error) {
var jsonErrors []jsonschema.KeyError
var schemaBytes = []byte(schema)
// Make sure the body is valid JSON
if !isJSON(requestBody) {
return jsonErrors, ErrInvalidJSON
}
rs := &jsonschema.Schema{}
if err := json.Unmarshal(schemaBytes, rs); err != nil {
return jsonErrors, err
}
var validationErr error
ctx := context.TODO()
if jsonErrors, validationErr = rs.ValidateBytes(ctx, requestBody); len(jsonErrors) > 0 {
return jsonErrors, validationErr
}
// Valid
return nil, nil
}
func isJSON(bytes []byte) bool {
var js map[string]interface{}
return json.Unmarshal(bytes, &js) == nil
}

View File

@ -11,6 +11,12 @@ import (
"npm/internal/logger"
"github.com/qri-io/jsonschema"
"github.com/rotisserie/eris"
)
var (
// ErrInvalidPayload is an error for invalid incoming data
ErrInvalidPayload = eris.New("Payload is invalid")
)
// Response interface for standard API results

View File

@ -0,0 +1,180 @@
package http
import (
"io"
"net/http"
"net/http/httptest"
"npm/internal/entity/user"
"npm/internal/model"
"testing"
"github.com/qri-io/jsonschema"
"github.com/stretchr/testify/assert"
)
func TestResultResponseJSON(t *testing.T) {
tests := []struct {
name string
status int
given interface{}
want string
}{
{
name: "simple response",
status: http.StatusOK,
given: true,
want: "{\"result\":true}",
},
{
name: "detailed response",
status: http.StatusBadRequest,
given: user.Model{
ModelBase: model.ModelBase{ID: 10},
Email: "me@example.com",
Name: "John Doe",
Nickname: "Jonny",
},
want: "{\"result\":{\"id\":10,\"created_at\":0,\"updated_at\":0,\"name\":\"John Doe\",\"nickname\":\"Jonny\",\"email\":\"me@example.com\",\"is_disabled\":false,\"gravatar_url\":\"\"}}",
},
{
name: "error response",
status: http.StatusNotFound,
given: ErrorResponse{
Code: 404,
Message: "Not found",
Invalid: []string{"your", "page", "was", "not", "found"},
},
want: "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\",\"invalid\":[\"your\",\"page\",\"was\",\"not\",\"found\"]}}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
w := httptest.NewRecorder()
ResultResponseJSON(w, r, tt.status, tt.given)
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("expected error to be nil got %v", err)
}
assert.Equal(t, tt.want, string(body))
assert.Equal(t, tt.status, res.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
})
}
}
func TestResultSchemaErrorJSON(t *testing.T) {
tests := []struct {
name string
given []jsonschema.KeyError
want string
}{
{
name: "case a",
given: []jsonschema.KeyError{
{
PropertyPath: "/something",
InvalidValue: "name",
Message: "Name cannot be empty",
},
},
want: "{\"result\":null,\"error\":{\"code\":400,\"message\":{},\"invalid\":[{\"propertyPath\":\"/something\",\"invalidValue\":\"name\",\"message\":\"Name cannot be empty\"}]}}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
w := httptest.NewRecorder()
ResultSchemaErrorJSON(w, r, tt.given)
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("expected error to be nil got %v", err)
}
assert.Equal(t, tt.want, string(body))
assert.Equal(t, 400, res.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
})
}
}
func TestResultErrorJSON(t *testing.T) {
tests := []struct {
name string
status int
message string
extended interface{}
want string
}{
{
name: "case a",
status: http.StatusBadGateway,
message: "Oh not something is not acceptable",
extended: nil,
want: "{\"result\":null,\"error\":{\"code\":502,\"message\":\"Oh not something is not acceptable\"}}",
},
{
name: "case b",
status: http.StatusNotAcceptable,
message: "Oh not something is not acceptable again",
extended: []string{"name is not allowed", "dob is wrong or something"},
want: "{\"result\":null,\"error\":{\"code\":406,\"message\":\"Oh not something is not acceptable again\",\"invalid\":[\"name is not allowed\",\"dob is wrong or something\"]}}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
w := httptest.NewRecorder()
ResultErrorJSON(w, r, tt.status, tt.message, tt.extended)
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("expected error to be nil got %v", err)
}
assert.Equal(t, tt.want, string(body))
assert.Equal(t, tt.status, res.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
})
}
}
func TestNotFound(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
w := httptest.NewRecorder()
NotFound(w, r)
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("expected error to be nil got %v", err)
}
assert.Equal(t, "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\"}}", string(body))
assert.Equal(t, http.StatusNotFound, res.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
})
}
func TestResultResponseText(t *testing.T) {
t.Run("basic test", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
w := httptest.NewRecorder()
ResultResponseText(w, r, http.StatusOK, "omg this works")
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("expected error to be nil got %v", err)
}
assert.Equal(t, "omg this works", string(body))
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type"))
})
}

View File

@ -22,3 +22,12 @@ func Expansion(next http.Handler) http.Handler {
}
})
}
// GetExpandFromContext returns the Expansion setting
func GetExpandFromContext(r *http.Request) []string {
expand, ok := r.Context().Value(c.ExpansionCtxKey).([]string)
if !ok {
return nil
}
return expand
}

View File

@ -35,7 +35,7 @@ func TestGetHealthz(t *testing.T) {
func TestNonExistent(t *testing.T) {
respRec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/non-existent-endpoint", nil)
req, _ := http.NewRequest("GET", "/non-existent-endpoint.jpg", nil)
r.ServeHTTP(respRec, req)
assert.Equal(t, http.StatusNotFound, respRec.Code)