diff --git a/backend/internal/api/handler/certificates.go b/backend/internal/api/handler/certificates.go index e9943936..44f5f3d0 100644 --- a/backend/internal/api/handler/certificates.go +++ b/backend/internal/api/handler/certificates.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "npm/internal/api/middleware" "npm/internal/api/schema" "npm/internal/entity/certificate" + "npm/internal/entity/host" "npm/internal/jobqueue" "npm/internal/logger" ) @@ -141,11 +143,20 @@ func DeleteCertificate() func(http.ResponseWriter, *http.Request) { return } - cert, err := certificate.GetByID(certificateID) - if err != nil { + item, err := certificate.GetByID(certificateID) + switch err { + case sql.ErrNoRows: + h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + case nil: + // Ensure that this upstream isn't in use by a host + cnt := host.GetCertificateUseCount(certificateID) + if cnt > 0 { + h.ResultErrorJSON(w, r, http.StatusBadRequest, "Cannot delete certificate that is in use by at least 1 host", nil) + return + } + h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, cert.Delete()) } } } diff --git a/backend/internal/entity/host/methods.go b/backend/internal/entity/host/methods.go index 0c0b0781..ddb1ebd4 100644 --- a/backend/internal/entity/host/methods.go +++ b/backend/internal/entity/host/methods.go @@ -213,6 +213,21 @@ func GetUpstreamUseCount(upstreamID int) int { return totalRows } +// GetCertificateUseCount returns the number of hosts that are using +// a certificate, and have not been deleted. +func GetCertificateUseCount(certificateID int) int { + db := database.GetInstance() + query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE certificate_id = ? AND is_deleted = ?", tableName) + countRow := db.QueryRowx(query, certificateID, 0) + var totalRows int + queryErr := countRow.Scan(&totalRows) + if queryErr != nil && queryErr != sql.ErrNoRows { + logger.Debug("%s", query) + return 0 + } + return totalRows +} + // AddPendingJobs is intended to be used at startup to add // anything pending to the JobQueue just once, based on // the database row status