diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 298b30735a..787a9e40fa 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -41,17 +41,13 @@ async def list_model_records( ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records - if base_models and len(base_models) > 0: - models_raw = list() + models = list() + if base_models: for base_model in base_models: - models_raw.extend( - [x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)] - ) + models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type)) else: - models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)] - models = ModelsListValidator.validate_python({"models": models_raw}) - return models - + models.extend(record_store.search_by_attr(model_type=model_type)) + return ModelsList(models=models) @model_records_router.get( "/i/{key}", diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 0fb027e3fa..a353fd88e0 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -107,7 +107,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): config TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- unique constraint on combo of name, base and type + UNIQUE(name, base, type) ); """ ) @@ -200,6 +202,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): if "UNIQUE constraint failed" in str(e): if "model_config.path" in str(e): msg = f"A model with path '{record.path}' is already installed" + elif "model_config.name" in str(e): + msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed" else: msg = f"A model with key '{key}' is already installed" raise DuplicateModelException(msg) from e diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index ff835b1f3f..2ffdbd9919 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -298,7 +298,7 @@ class ModelConfigFactory(object): @classmethod def make_config( cls, - model_data: Union[dict, ModelConfigBase], + model_data: Union[dict, AnyModelConfig], key: Optional[str] = None, dest_class: Optional[Type] = None, ) -> AnyModelConfig: diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 52a7c40dfd..8d99e581a8 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -170,6 +170,44 @@ def test_filter(store: ModelRecordServiceBase): matches = store.all_models() assert len(matches) == 3 +def test_unique(store: ModelRecordServiceBase): + config1 = MainDiffusersConfig( + path="/tmp/config1", + base=BaseModelType("sd-1"), + type=ModelType("main"), + name="nonuniquename", + original_hash="CONFIG1HASH", + ) + config2 = MainDiffusersConfig( + path="/tmp/config2", + base=BaseModelType("sd-2"), + type=ModelType("main"), + name="nonuniquename", + original_hash="CONFIG1HASH", + ) + config3 = VaeDiffusersConfig( + path="/tmp/config3", + base=BaseModelType("sd-2"), + type=ModelType("vae"), + name="nonuniquename", + original_hash="CONFIG1HASH", + ) + config4 = MainDiffusersConfig( + path="/tmp/config4", + base=BaseModelType("sd-1"), + type=ModelType("main"), + name="nonuniquename", + original_hash="CONFIG1HASH", + ) + # config1, config2 and config3 are compatible because they have unique combos + # of name, type and base + for c in config1, config2, config3: + store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) + + # config4 clashes with config1 and should raise an integrity error + with pytest.raises(DuplicateModelException): + store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4) + def test_filter_2(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( @@ -196,7 +234,7 @@ def test_filter_2(store: ModelRecordServiceBase): config4 = MainDiffusersConfig( path="/tmp/config4", name="dup_name1", - base=BaseModelType("sd-2"), + base=BaseModelType("sdxl"), type=ModelType("main"), original_hash="CONFIG3HASH", )