diff --git a/backend/internal/api/handler/upstreams.go b/backend/internal/api/handler/upstreams.go index d5469f81..4794ece6 100644 --- a/backend/internal/api/handler/upstreams.go +++ b/backend/internal/api/handler/upstreams.go @@ -9,6 +9,7 @@ import ( c "npm/internal/api/context" h "npm/internal/api/http" "npm/internal/api/middleware" + "npm/internal/entity/host" "npm/internal/entity/upstream" "npm/internal/jobqueue" "npm/internal/logger" @@ -151,6 +152,12 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) { case sql.ErrNoRows: h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) case nil: + // Ensure that this upstream isn't in use by a host + cnt := host.GetUpstreamUseCount(upstreamID) + if cnt > 0 { + h.ResultErrorJSON(w, r, http.StatusBadRequest, "Cannot delete upstream that is in use by at least 1 host", nil) + return + } h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) configureUpstream(item) default: diff --git a/backend/internal/entity/host/methods.go b/backend/internal/entity/host/methods.go index b40fa018..0c0b0781 100644 --- a/backend/internal/entity/host/methods.go +++ b/backend/internal/entity/host/methods.go @@ -29,7 +29,7 @@ func create(host *Model) (int, error) { db := database.GetInstance() // nolint: gosec - result, err := db.NamedExec(`INSERT INTO `+fmt.Sprintf("`%s`", tableName)+` ( + result, err := db.NamedExec(`INSERT INTO `+tableName+` ( created_on, modified_on, user_id, @@ -198,6 +198,21 @@ func List(pageInfo model.PageInfo, filters []model.Filter, expand []string) (Lis return result, nil } +// GetUpstreamUseCount returns the number of hosts that are using +// an upstream, and have not been deleted. +func GetUpstreamUseCount(upstreamID int) int { + db := database.GetInstance() + query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE upstream_id = ? AND is_deleted = ?", tableName) + countRow := db.QueryRowx(query, upstreamID, 0) + var totalRows int + queryErr := countRow.Scan(&totalRows) + if queryErr != nil && queryErr != sql.ErrNoRows { + logger.Debug("%s", query) + return 0 + } + return totalRows +} + // AddPendingJobs is intended to be used at startup to add // anything pending to the JobQueue just once, based on // the database row status