diff --git a/backend/embed/migrations/postgres/20201013035318_initial_schema.sql b/backend/embed/migrations/postgres/20201013035318_initial_schema.sql index e49a2da0..b68bbb0d 100644 --- a/backend/embed/migrations/postgres/20201013035318_initial_schema.sql +++ b/backend/embed/migrations/postgres/20201013035318_initial_schema.sql @@ -139,7 +139,7 @@ CREATE TABLE "upstream" ( "created_at" BIGINT NOT NULL DEFAULT 0, "updated_at" BIGINT NOT NULL DEFAULT 0, "is_deleted" INTEGER NOT NULL DEFAULT 0, -- int on purpose, gormism - "user_id" INTEGER NOT NULL REFERENCES "user"("id"), + "user_id" INTEGER NOT NULL REFERENCES "user"("id") ON DELETE CASCADE, "name" VARCHAR(50) NOT NULL, "nginx_template_id" INTEGER NOT NULL REFERENCES "nginx_template"("id") ON DELETE CASCADE, "ip_hash" BOOLEAN NOT NULL DEFAULT FALSE, diff --git a/backend/internal/database/helpers.go b/backend/internal/database/helpers.go index 0cd74473..0624cbda 100644 --- a/backend/internal/database/helpers.go +++ b/backend/internal/database/helpers.go @@ -1,8 +1,27 @@ package database +import ( + "fmt" + "npm/internal/config" + "strings" +) + const ( // DateFormat for DateFormat DateFormat = "2006-01-02" // DateTimeFormat for DateTimeFormat DateTimeFormat = "2006-01-02T15:04:05" ) + +// QuoteTableName is a special function that will quote a table +// name based on the driver. Gorm normally handles this but this +// is for special cases where we run raw sql +func QuoteTableName(tbl string) string { + switch strings.ToLower(config.Configuration.DB.Driver) { + case config.DatabasePostgres: + return fmt.Sprintf(`"%s"`, tbl) + default: + // This is the same for Mysql and Sqlite + return fmt.Sprintf("`%s`", tbl) + } +} diff --git a/backend/internal/entity/accesslist/methods.go b/backend/internal/entity/accesslist/methods.go index 650aeb09..9485d51f 100644 --- a/backend/internal/entity/accesslist/methods.go +++ b/backend/internal/entity/accesslist/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/certificate/methods.go b/backend/internal/entity/certificate/methods.go index 69224bbf..28b0ca46 100644 --- a/backend/internal/entity/certificate/methods.go +++ b/backend/internal/entity/certificate/methods.go @@ -37,7 +37,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -47,7 +47,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/certificateauthority/methods.go b/backend/internal/entity/certificateauthority/methods.go index d3f911ef..c77c370e 100644 --- a/backend/internal/entity/certificateauthority/methods.go +++ b/backend/internal/entity/certificateauthority/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/dnsprovider/methods.go b/backend/internal/entity/dnsprovider/methods.go index b0081565..036e4981 100644 --- a/backend/internal/entity/dnsprovider/methods.go +++ b/backend/internal/entity/dnsprovider/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/host/methods.go b/backend/internal/entity/host/methods.go index d4dd9559..52209235 100644 --- a/backend/internal/entity/host/methods.go +++ b/backend/internal/entity/host/methods.go @@ -23,7 +23,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -33,7 +33,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/lists.go b/backend/internal/entity/lists.go index f3fbb7cf..7b51e11d 100644 --- a/backend/internal/entity/lists.go +++ b/backend/internal/entity/lists.go @@ -25,13 +25,21 @@ type ListResponse struct { // ListQueryBuilder is used to setup queries for lists func ListQueryBuilder( pageInfo *model.PageInfo, - defaultSort model.Sort, filters []model.Filter, filterMap map[string]filterMapValue, ) *gorm.DB { scopes := make([]func(*gorm.DB) *gorm.DB, 0) - scopes = append(scopes, ScopeOrderBy(pageInfo, defaultSort)) scopes = append(scopes, ScopeOffsetLimit(pageInfo)) scopes = append(scopes, ScopeFilters(filters, filterMap)) return database.GetDB().Scopes(scopes...) } + +// AddOrderToList is used after query above is used for counting +// Postgres in particular doesn't like count(*) when ordering at the same time +func AddOrderToList( + dbo *gorm.DB, + pageInfo *model.PageInfo, + defaultSort model.Sort, +) *gorm.DB { + return dbo.Scopes(ScopeOrderBy(pageInfo, defaultSort)) +} diff --git a/backend/internal/entity/nginxtemplate/methods.go b/backend/internal/entity/nginxtemplate/methods.go index 5c51dd2f..aaa320b8 100644 --- a/backend/internal/entity/nginxtemplate/methods.go +++ b/backend/internal/entity/nginxtemplate/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/setting/methods.go b/backend/internal/entity/setting/methods.go index 4886c1ff..63d6fb1a 100644 --- a/backend/internal/entity/setting/methods.go +++ b/backend/internal/entity/setting/methods.go @@ -30,7 +30,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -40,7 +40,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/stream/methods.go b/backend/internal/entity/stream/methods.go index d70262bf..bbd38e27 100644 --- a/backend/internal/entity/stream/methods.go +++ b/backend/internal/entity/stream/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/upstream/methods.go b/backend/internal/entity/upstream/methods.go index f8859646..2c090658 100644 --- a/backend/internal/entity/upstream/methods.go +++ b/backend/internal/entity/upstream/methods.go @@ -21,7 +21,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -31,7 +31,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/upstreamserver/methods.go b/backend/internal/entity/upstreamserver/methods.go index eaf1057b..e14b8485 100644 --- a/backend/internal/entity/upstreamserver/methods.go +++ b/backend/internal/entity/upstreamserver/methods.go @@ -30,7 +30,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -40,7 +40,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter) (entity.ListResponse, // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } diff --git a/backend/internal/entity/user/methods.go b/backend/internal/entity/user/methods.go index b6dbb19a..6ae14350 100644 --- a/backend/internal/entity/user/methods.go +++ b/backend/internal/entity/user/methods.go @@ -1,6 +1,7 @@ package user import ( + "fmt" "npm/internal/database" "npm/internal/entity" "npm/internal/logger" @@ -41,7 +42,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent Direction: "ASC", } - dbo := entity.ListQueryBuilder(&pageInfo, defaultSort, filters, entity.GetFilterMap(Model{}, true)) + dbo := entity.ListQueryBuilder(&pageInfo, filters, entity.GetFilterMap(Model{}, true)) // Get count of items in this search var totalRows int64 @@ -51,7 +52,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent // Get rows items := make([]Model, 0) - if res := dbo.Find(&items); res.Error != nil { + if res := entity.AddOrderToList(dbo, &pageInfo, defaultSort).Find(&items); res.Error != nil { return result, res.Error } @@ -84,8 +85,7 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (ent func DeleteAll() error { db := database.GetDB() // nolint errcheck - db.Exec("DELETE FROM auth") - result := db.Exec("DELETE FROM user WHERE is_system = ?", false) + result := db.Exec(fmt.Sprintf(`DELETE FROM %s WHERE is_system = ?`, database.QuoteTableName("user")), false) return result.Error }