mirror of
https://github.com/jc21/nginx-proxy-manager.git
synced 2024-08-30 18:22:48 +00:00
Better checking for api sort param to prevent sql injection
And moved filters out and cached object reflection
This commit is contained in:
parent
9b32329f41
commit
a0e17f9678
@ -7,6 +7,8 @@ var (
|
|||||||
UserIDCtxKey = &contextKey{"UserID"}
|
UserIDCtxKey = &contextKey{"UserID"}
|
||||||
// FiltersCtxKey is the name of the Filters value on the context
|
// FiltersCtxKey is the name of the Filters value on the context
|
||||||
FiltersCtxKey = &contextKey{"Filters"}
|
FiltersCtxKey = &contextKey{"Filters"}
|
||||||
|
// SortCtxKey is the name of the Sort value on the context
|
||||||
|
SortCtxKey = &contextKey{"Sort"}
|
||||||
// PrettyPrintCtxKey is the name of the pretty print context
|
// PrettyPrintCtxKey is the name of the pretty print context
|
||||||
PrettyPrintCtxKey = &contextKey{"Pretty"}
|
PrettyPrintCtxKey = &contextKey{"Pretty"}
|
||||||
// ExpansionCtxKey is the name of the expansion context
|
// ExpansionCtxKey is the name of the expansion context
|
||||||
|
@ -3,9 +3,9 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"npm/internal/api/context"
|
"npm/internal/api/context"
|
||||||
|
"npm/internal/api/middleware"
|
||||||
"npm/internal/model"
|
"npm/internal/model"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@ -23,50 +23,11 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
|
|||||||
return pageInfo, err
|
return pageInfo, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pageInfo.Sort = getSortParameter(r)
|
pageInfo.Sort = middleware.GetSortFromContext(r)
|
||||||
|
|
||||||
return pageInfo, nil
|
return pageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSortParameter(r *http.Request) []model.Sort {
|
|
||||||
var sortFields []model.Sort
|
|
||||||
|
|
||||||
queryValues := r.URL.Query()
|
|
||||||
sortString := queryValues.Get("sort")
|
|
||||||
if sortString == "" {
|
|
||||||
return sortFields
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split sort fields up in to slice
|
|
||||||
sorts := strings.Split(sortString, ",")
|
|
||||||
for _, sortItem := range sorts {
|
|
||||||
if strings.Contains(sortItem, ".") {
|
|
||||||
theseItems := strings.Split(sortItem, ".")
|
|
||||||
|
|
||||||
switch strings.ToLower(theseItems[1]) {
|
|
||||||
case "desc":
|
|
||||||
fallthrough
|
|
||||||
case "descending":
|
|
||||||
theseItems[1] = "DESC"
|
|
||||||
default:
|
|
||||||
theseItems[1] = "ASC"
|
|
||||||
}
|
|
||||||
|
|
||||||
sortFields = append(sortFields, model.Sort{
|
|
||||||
Field: theseItems[0],
|
|
||||||
Direction: theseItems[1],
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
sortFields = append(sortFields, model.Sort{
|
|
||||||
Field: sortItem,
|
|
||||||
Direction: "ASC",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sortFields
|
|
||||||
}
|
|
||||||
|
|
||||||
func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) {
|
func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) {
|
||||||
queryValues := r.URL.Query()
|
queryValues := r.URL.Query()
|
||||||
varValue := queryValues.Get(varName)
|
varValue := queryValues.Get(varName)
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Expansion will determine whether the request should have objects expanded
|
// Expansion will determine whether the request should have objects expanded
|
||||||
// with ?expand=1 or ?expand=true
|
// with ?expand=item,item
|
||||||
func Expansion(next http.Handler) http.Handler {
|
func Expansion(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
expandStr := r.URL.Query().Get("expand")
|
expandStr := r.URL.Query().Get("expand")
|
||||||
|
@ -1,118 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
c "npm/internal/api/context"
|
|
||||||
h "npm/internal/api/http"
|
|
||||||
"npm/internal/entity"
|
|
||||||
"npm/internal/model"
|
|
||||||
"npm/internal/util"
|
|
||||||
|
|
||||||
"github.com/qri-io/jsonschema"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Filters will accept a pre-defined schemaData to validate against the GET query params
|
|
||||||
// passed in to this endpoint. This will ensure that the filters are not injecting SQL.
|
|
||||||
// After we have determined what the Filters are to be, they are saved on the Context
|
|
||||||
// to be used later in other endpoints.
|
|
||||||
func Filters(obj interface{}) func(http.Handler) http.Handler {
|
|
||||||
schemaData := entity.GetFilterSchema(obj, true)
|
|
||||||
|
|
||||||
reservedFilterKeys := []string{
|
|
||||||
"limit",
|
|
||||||
"offset",
|
|
||||||
"sort",
|
|
||||||
"order",
|
|
||||||
"expand",
|
|
||||||
"t", // This is used as a timestamp paramater in some clients and can be ignored
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var filters []model.Filter
|
|
||||||
for key, val := range r.URL.Query() {
|
|
||||||
key = strings.ToLower(key)
|
|
||||||
|
|
||||||
// Split out the modifier from the field name and set a default modifier
|
|
||||||
var keyParts []string
|
|
||||||
keyParts = strings.Split(key, ":")
|
|
||||||
if len(keyParts) == 1 {
|
|
||||||
// Default modifier
|
|
||||||
keyParts = append(keyParts, "equals")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only use this filter if it's not a reserved get param
|
|
||||||
if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) {
|
|
||||||
for _, valItem := range val {
|
|
||||||
// Check that the val isn't empty
|
|
||||||
if len(strings.TrimSpace(valItem)) > 0 {
|
|
||||||
valSlice := []string{valItem}
|
|
||||||
if keyParts[1] == "in" || keyParts[1] == "notin" {
|
|
||||||
valSlice = strings.Split(valItem, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
filters = append(filters, model.Filter{
|
|
||||||
Field: keyParts[0],
|
|
||||||
Modifier: keyParts[1],
|
|
||||||
Value: valSlice,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only validate schema if there are filters to validate
|
|
||||||
if len(filters) > 0 {
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
// Marshal the Filters in to a JSON string so that the Schema Validation works against it
|
|
||||||
filterData, marshalErr := json.MarshalIndent(filters, "", " ")
|
|
||||||
if marshalErr != nil {
|
|
||||||
h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create root schema
|
|
||||||
rs := &jsonschema.Schema{}
|
|
||||||
if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
|
|
||||||
h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate it
|
|
||||||
errors, jsonError := rs.ValidateBytes(ctx, filterData)
|
|
||||||
if jsonError != nil {
|
|
||||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, jsonError.Error(), nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, "Invalid Filters", errors)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo: populate filters object with the gorm database name
|
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, c.FiltersCtxKey, filters)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
} else {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFiltersFromContext returns the Filters
|
|
||||||
func GetFiltersFromContext(r *http.Request) []model.Filter {
|
|
||||||
filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter)
|
|
||||||
if !ok {
|
|
||||||
// the assertion failed
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return filters
|
|
||||||
}
|
|
196
backend/internal/api/middleware/list_query.go
Normal file
196
backend/internal/api/middleware/list_query.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
c "npm/internal/api/context"
|
||||||
|
h "npm/internal/api/http"
|
||||||
|
"npm/internal/entity"
|
||||||
|
"npm/internal/model"
|
||||||
|
"npm/internal/tags"
|
||||||
|
"npm/internal/util"
|
||||||
|
|
||||||
|
"github.com/qri-io/jsonschema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListQuery will accept a pre-defined schemaData to validate against the GET query params
|
||||||
|
// passed in to this endpoint. This will ensure that the filters are not injecting SQL
|
||||||
|
// and the sort parameter is valid as well.
|
||||||
|
// After we have determined what the Filters are to be, they are saved on the Context
|
||||||
|
// to be used later in other endpoints.
|
||||||
|
func ListQuery(obj interface{}) func(http.Handler) http.Handler {
|
||||||
|
schemaData := entity.GetFilterSchema(obj, true)
|
||||||
|
filterMap := tags.GetFilterMap(obj)
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
ctx, statusCode, errMsg, errors := listQueryFilters(r, ctx, schemaData)
|
||||||
|
if statusCode > 0 {
|
||||||
|
h.ResultErrorJSON(w, r, statusCode, errMsg, errors)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, statusCode, errMsg = listQuerySort(r, filterMap, ctx)
|
||||||
|
if statusCode > 0 {
|
||||||
|
h.ResultErrorJSON(w, r, statusCode, errMsg, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func listQuerySort(
|
||||||
|
r *http.Request,
|
||||||
|
filterMap map[string]model.FilterMapValue,
|
||||||
|
ctx context.Context,
|
||||||
|
) (context.Context, int, string) {
|
||||||
|
var sortFields []model.Sort
|
||||||
|
|
||||||
|
sortString := r.URL.Query().Get("sort")
|
||||||
|
if sortString == "" {
|
||||||
|
return ctx, 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split sort fields up in to slice
|
||||||
|
sorts := strings.Split(sortString, ",")
|
||||||
|
for _, sortItem := range sorts {
|
||||||
|
if strings.Contains(sortItem, ".") {
|
||||||
|
theseItems := strings.Split(sortItem, ".")
|
||||||
|
|
||||||
|
switch strings.ToLower(theseItems[1]) {
|
||||||
|
case "desc":
|
||||||
|
fallthrough
|
||||||
|
case "descending":
|
||||||
|
theseItems[1] = "DESC"
|
||||||
|
default:
|
||||||
|
theseItems[1] = "ASC"
|
||||||
|
}
|
||||||
|
|
||||||
|
sortFields = append(sortFields, model.Sort{
|
||||||
|
Field: theseItems[0],
|
||||||
|
Direction: theseItems[1],
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
sortFields = append(sortFields, model.Sort{
|
||||||
|
Field: sortItem,
|
||||||
|
Direction: "ASC",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check against filter schema
|
||||||
|
for _, f := range sortFields {
|
||||||
|
if _, exists := filterMap[f.Field]; !exists {
|
||||||
|
return ctx, http.StatusBadRequest, "Invalid sort field"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, c.SortCtxKey, sortFields)
|
||||||
|
|
||||||
|
// No problems!
|
||||||
|
return ctx, 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func listQueryFilters(
|
||||||
|
r *http.Request,
|
||||||
|
ctx context.Context,
|
||||||
|
schemaData string,
|
||||||
|
) (context.Context, int, string, interface{}) {
|
||||||
|
reservedFilterKeys := []string{
|
||||||
|
"limit",
|
||||||
|
"offset",
|
||||||
|
"sort",
|
||||||
|
"expand",
|
||||||
|
"t", // This is used as a timestamp paramater in some clients and can be ignored
|
||||||
|
}
|
||||||
|
|
||||||
|
var filters []model.Filter
|
||||||
|
for key, val := range r.URL.Query() {
|
||||||
|
key = strings.ToLower(key)
|
||||||
|
|
||||||
|
// Split out the modifier from the field name and set a default modifier
|
||||||
|
var keyParts []string
|
||||||
|
keyParts = strings.Split(key, ":")
|
||||||
|
if len(keyParts) == 1 {
|
||||||
|
// Default modifier
|
||||||
|
keyParts = append(keyParts, "equals")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only use this filter if it's not a reserved get param
|
||||||
|
if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) {
|
||||||
|
for _, valItem := range val {
|
||||||
|
// Check that the val isn't empty
|
||||||
|
if len(strings.TrimSpace(valItem)) > 0 {
|
||||||
|
valSlice := []string{valItem}
|
||||||
|
if keyParts[1] == "in" || keyParts[1] == "notin" {
|
||||||
|
valSlice = strings.Split(valItem, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
filters = append(filters, model.Filter{
|
||||||
|
Field: keyParts[0],
|
||||||
|
Modifier: keyParts[1],
|
||||||
|
Value: valSlice,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only validate schema if there are filters to validate
|
||||||
|
if len(filters) > 0 {
|
||||||
|
// Marshal the Filters in to a JSON string so that the Schema Validation works against it
|
||||||
|
filterData, marshalErr := json.MarshalIndent(filters, "", " ")
|
||||||
|
if marshalErr != nil {
|
||||||
|
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create root schema
|
||||||
|
rs := &jsonschema.Schema{}
|
||||||
|
if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
|
||||||
|
return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate it
|
||||||
|
errors, jsonError := rs.ValidateBytes(ctx, filterData)
|
||||||
|
if jsonError != nil {
|
||||||
|
return ctx, http.StatusBadRequest, jsonError.Error(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return ctx, http.StatusBadRequest, "Invalid Filters", errors
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(ctx, c.FiltersCtxKey, filters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No problems!
|
||||||
|
return ctx, 0, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFiltersFromContext returns the Filters
|
||||||
|
func GetFiltersFromContext(r *http.Request) []model.Filter {
|
||||||
|
filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter)
|
||||||
|
if !ok {
|
||||||
|
// the assertion failed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return filters
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSortFromContext returns the Sort
|
||||||
|
func GetSortFromContext(r *http.Request) []model.Sort {
|
||||||
|
sorts, ok := r.Context().Value(c.SortCtxKey).([]model.Sort)
|
||||||
|
if !ok {
|
||||||
|
// the assertion failed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sorts
|
||||||
|
}
|
@ -104,7 +104,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityUsersManage),
|
middleware.Enforce(user.CapabilityUsersManage),
|
||||||
middleware.Filters(user.Model{}),
|
middleware.ListQuery(user.Model{}),
|
||||||
).Get("/", handler.GetUsers())
|
).Get("/", handler.GetUsers())
|
||||||
|
|
||||||
// Specific Item
|
// Specific Item
|
||||||
@ -136,7 +136,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
|
r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) {
|
||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Filters(setting.Model{}),
|
middleware.ListQuery(setting.Model{}),
|
||||||
).Get("/", handler.GetSettings())
|
).Get("/", handler.GetSettings())
|
||||||
|
|
||||||
r.Get("/{name}", handler.GetSetting())
|
r.Get("/{name}", handler.GetSetting())
|
||||||
@ -151,7 +151,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityAccessListsView),
|
middleware.Enforce(user.CapabilityAccessListsView),
|
||||||
middleware.Filters(accesslist.Model{}),
|
middleware.ListQuery(accesslist.Model{}),
|
||||||
).Get("/", handler.GetAccessLists())
|
).Get("/", handler.GetAccessLists())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -175,7 +175,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityDNSProvidersView),
|
middleware.Enforce(user.CapabilityDNSProvidersView),
|
||||||
middleware.Filters(dnsprovider.Model{}),
|
middleware.ListQuery(dnsprovider.Model{}),
|
||||||
).Get("/", handler.GetDNSProviders())
|
).Get("/", handler.GetDNSProviders())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -205,7 +205,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityCertificateAuthoritiesView),
|
middleware.Enforce(user.CapabilityCertificateAuthoritiesView),
|
||||||
middleware.Filters(certificateauthority.Model{}),
|
middleware.ListQuery(certificateauthority.Model{}),
|
||||||
).Get("/", handler.GetCertificateAuthorities())
|
).Get("/", handler.GetCertificateAuthorities())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -235,7 +235,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityCertificatesView),
|
middleware.Enforce(user.CapabilityCertificatesView),
|
||||||
middleware.Filters(certificate.Model{}),
|
middleware.ListQuery(certificate.Model{}),
|
||||||
).Get("/", handler.GetCertificates())
|
).Get("/", handler.GetCertificates())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -262,7 +262,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityHostsView),
|
middleware.Enforce(user.CapabilityHostsView),
|
||||||
middleware.Filters(host.Model{}),
|
middleware.ListQuery(host.Model{}),
|
||||||
).Get("/", handler.GetHosts())
|
).Get("/", handler.GetHosts())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -288,7 +288,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityNginxTemplatesView),
|
middleware.Enforce(user.CapabilityNginxTemplatesView),
|
||||||
middleware.Filters(nginxtemplate.Model{}),
|
middleware.ListQuery(nginxtemplate.Model{}),
|
||||||
).Get("/", handler.GetNginxTemplates())
|
).Get("/", handler.GetNginxTemplates())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -312,7 +312,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityStreamsView),
|
middleware.Enforce(user.CapabilityStreamsView),
|
||||||
middleware.Filters(stream.Model{}),
|
middleware.ListQuery(stream.Model{}),
|
||||||
).Get("/", handler.GetStreams())
|
).Get("/", handler.GetStreams())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
@ -336,7 +336,7 @@ func applyRoutes(r chi.Router) chi.Router {
|
|||||||
// List
|
// List
|
||||||
r.With(
|
r.With(
|
||||||
middleware.Enforce(user.CapabilityHostsView),
|
middleware.Enforce(user.CapabilityHostsView),
|
||||||
middleware.Filters(upstream.Model{}),
|
middleware.ListQuery(upstream.Model{}),
|
||||||
).Get("/", handler.GetUpstreams())
|
).Get("/", handler.GetUpstreams())
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
|
@ -1,66 +1,22 @@
|
|||||||
package entity
|
package entity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"npm/internal/model"
|
||||||
"regexp"
|
"npm/internal/tags"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type filterMapValue struct {
|
|
||||||
Type string
|
|
||||||
Field string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFilterMap returns the filter map
|
// GetFilterMap returns the filter map
|
||||||
func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]filterMapValue {
|
func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]model.FilterMapValue {
|
||||||
filterMap := getFilterMapForInterface(m)
|
filterMap := tags.GetFilterMap(m)
|
||||||
if includeBaseEntity {
|
if includeBaseEntity {
|
||||||
return mergeFilterMaps(getFilterMapForInterface(ModelBase{}), filterMap)
|
return mergeFilterMaps(tags.GetFilterMap(ModelBase{}), filterMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
return filterMap
|
return filterMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFilterMapForInterface(m interface{}) map[string]filterMapValue {
|
func mergeFilterMaps(m1 map[string]model.FilterMapValue, m2 map[string]model.FilterMapValue) map[string]model.FilterMapValue {
|
||||||
var filterMap = make(map[string]filterMapValue)
|
merged := make(map[string]model.FilterMapValue, 0)
|
||||||
|
|
||||||
// TypeOf returns the reflection Type that represents the dynamic type of variable.
|
|
||||||
// If variable is a nil interface value, TypeOf returns nil.
|
|
||||||
t := reflect.TypeOf(m)
|
|
||||||
|
|
||||||
// Iterate over all available fields and read the tag value
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
// Get the field, returns https://golang.org/pkg/reflect/#StructField
|
|
||||||
field := t.Field(i)
|
|
||||||
|
|
||||||
// Get the field tag value
|
|
||||||
filterTag := field.Tag.Get("filter")
|
|
||||||
dbTag := field.Tag.Get("gorm")
|
|
||||||
if filterTag != "" && dbTag != "" && dbTag != "-" && filterTag != "-" {
|
|
||||||
// db can have many parts, we need to pull out the "column:value" part
|
|
||||||
dbField := field.Name
|
|
||||||
r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`)
|
|
||||||
if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 {
|
|
||||||
dbField = matches[1]
|
|
||||||
}
|
|
||||||
// Filter tag can be a 2 part thing: name,type
|
|
||||||
// ie: account_id,integer
|
|
||||||
// So we need to split and use the first part
|
|
||||||
parts := strings.Split(filterTag, ",")
|
|
||||||
if len(parts) > 1 {
|
|
||||||
filterMap[parts[0]] = filterMapValue{
|
|
||||||
Type: parts[1],
|
|
||||||
Field: dbField,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filterMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeFilterMaps(m1 map[string]filterMapValue, m2 map[string]filterMapValue) map[string]filterMapValue {
|
|
||||||
merged := make(map[string]filterMapValue, 0)
|
|
||||||
for k, v := range m1 {
|
for k, v := range m1 {
|
||||||
merged[k] = v
|
merged[k] = v
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ type ListResponse struct {
|
|||||||
func ListQueryBuilder(
|
func ListQueryBuilder(
|
||||||
pageInfo *model.PageInfo,
|
pageInfo *model.PageInfo,
|
||||||
filters []model.Filter,
|
filters []model.Filter,
|
||||||
filterMap map[string]filterMapValue,
|
filterMap map[string]model.FilterMapValue,
|
||||||
) *gorm.DB {
|
) *gorm.DB {
|
||||||
scopes := make([]func(*gorm.DB) *gorm.DB, 0)
|
scopes := make([]func(*gorm.DB) *gorm.DB, 0)
|
||||||
scopes = append(scopes, ScopeOffsetLimit(pageInfo))
|
scopes = append(scopes, ScopeOffsetLimit(pageInfo))
|
||||||
|
@ -36,7 +36,7 @@ func ScopeOrderBy(pageInfo *model.PageInfo, defaultSort model.Sort) func(db *gor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ScopeFilters(filters []model.Filter, filterMap map[string]filterMapValue) func(db *gorm.DB) *gorm.DB {
|
func ScopeFilters(filters []model.Filter, filterMap map[string]model.FilterMapValue) func(db *gorm.DB) *gorm.DB {
|
||||||
return func(db *gorm.DB) *gorm.DB {
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
like := database.GetCaseInsensitiveLike()
|
like := database.GetCaseInsensitiveLike()
|
||||||
for _, f := range filters {
|
for _, f := range filters {
|
||||||
|
@ -6,3 +6,9 @@ type Filter struct {
|
|||||||
Modifier string `json:"modifier"`
|
Modifier string `json:"modifier"`
|
||||||
Value []string `json:"value"`
|
Value []string `json:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilterMapValue ...
|
||||||
|
type FilterMapValue struct {
|
||||||
|
Type string
|
||||||
|
Field string
|
||||||
|
}
|
||||||
|
53
backend/internal/tags/filters.go
Normal file
53
backend/internal/tags/filters.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package tags
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"npm/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetFilterMap(m interface{}) map[string]model.FilterMapValue {
|
||||||
|
name := getName(m)
|
||||||
|
if val, exists := getCache(name); exists {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
var filterMap = make(map[string]model.FilterMapValue)
|
||||||
|
|
||||||
|
// TypeOf returns the reflection Type that represents the dynamic type of variable.
|
||||||
|
// If variable is a nil interface value, TypeOf returns nil.
|
||||||
|
t := reflect.TypeOf(m)
|
||||||
|
|
||||||
|
// Iterate over all available fields and read the tag value
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
// Get the field, returns https://golang.org/pkg/reflect/#StructField
|
||||||
|
field := t.Field(i)
|
||||||
|
|
||||||
|
// Get the field tag value
|
||||||
|
filterTag := field.Tag.Get("filter")
|
||||||
|
dbTag := field.Tag.Get("gorm")
|
||||||
|
if filterTag != "" && dbTag != "" && dbTag != "-" && filterTag != "-" {
|
||||||
|
// db can have many parts, we need to pull out the "column:value" part
|
||||||
|
dbField := field.Name
|
||||||
|
r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`)
|
||||||
|
if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 {
|
||||||
|
dbField = matches[1]
|
||||||
|
}
|
||||||
|
// Filter tag can be a 2 part thing: name,type
|
||||||
|
// ie: account_id,integer
|
||||||
|
// So we need to split and use the first part
|
||||||
|
parts := strings.Split(filterTag, ",")
|
||||||
|
if len(parts) > 1 {
|
||||||
|
filterMap[parts[0]] = model.FilterMapValue{
|
||||||
|
Type: parts[1],
|
||||||
|
Field: dbField,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setCache(name, filterMap)
|
||||||
|
return filterMap
|
||||||
|
}
|
33
backend/internal/tags/reflect.go
Normal file
33
backend/internal/tags/reflect.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package tags
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"npm/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tagCache map[string]map[string]model.FilterMapValue
|
||||||
|
|
||||||
|
// getName returns the name of the type given
|
||||||
|
func getName(m interface{}) string {
|
||||||
|
fc := reflect.TypeOf(m)
|
||||||
|
return fmt.Sprint(fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCache tries to find a cached item for this name
|
||||||
|
func getCache(name string) (map[string]model.FilterMapValue, bool) {
|
||||||
|
if val, ok := tagCache[name]; ok {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCache sets the name to this value
|
||||||
|
func setCache(name string, val map[string]model.FilterMapValue) {
|
||||||
|
// Hack to initialise empty map
|
||||||
|
if len(tagCache) == 0 {
|
||||||
|
tagCache = make(map[string]map[string]model.FilterMapValue, 0)
|
||||||
|
}
|
||||||
|
tagCache[name] = val
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user