diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index d09ddecac7..26a384d376 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -39,7 +39,7 @@ class ModelRecordServiceBase(ABC): pass @abstractmethod - def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase: + def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig: """ Add a model to the database. @@ -110,7 +110,7 @@ class ModelRecordServiceBase(ABC): pass @abstractmethod - def search_by_name( + def search_by_attr( self, model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, @@ -130,16 +130,16 @@ class ModelRecordServiceBase(ABC): def all_models(self) -> List[AnyModelConfig]: """Return all the model configs in the database.""" - return self.search_by_name() + return self.search_by_attr() - def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: + def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig: """ Return information about a single model using its name, base type and model type. If there are more than one model that match, raises a DuplicateModelException. If no model matches, raises an UnknownModelException """ - model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type) + model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type) if len(model_configs) > 1: raise DuplicateModelException( f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." @@ -154,7 +154,7 @@ class ModelRecordServiceBase(ABC): self, key: str, new_name: str, - ) -> ModelConfigBase: + ) -> AnyModelConfig: """ Rename the indicated model. Just a special case of update_model(). @@ -164,4 +164,6 @@ class ModelRecordServiceBase(ABC): :param key: Model key :param new_name: New name for model """ - return self.update_model(key, {"name": new_name}) + config = self.get_model(key) + config.name = new_name + return self.update_model(key, config) diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index da911994bb..1af3ec8859 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -36,7 +36,7 @@ Typical usage: # searching configs = store.search_by_path(path='/tmp/pokemon.bin') configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01') - configs = store.search_by_name(base_model='sd-2', model_type='main') + configs = store.search_by_attr(base_model='sd-2', model_type='main') """ @@ -77,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """ super().__init__() self._db = db - self._db.conn.row_factory = sqlite3.Row self._cursor = self._db.conn.cursor() with self._db.lock: @@ -157,7 +156,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): ("version", CONFIG_FILE_VERSION), ) - def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig: """ Add a model to the database. @@ -168,7 +167,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): Can raise DuplicateModelException and InvalidModelConfigException exceptions. """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. - json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. + json_serialized = record.model_dump_json() # and turn it into a json string. with self._db.lock: try: self._cursor.execute( @@ -252,7 +251,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: + def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig: """ Update the model, returning the updated version. @@ -261,7 +260,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): required fields, or a ModelConfigBase instance. """ record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect - json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. + json_serialized = record.model_dump_json() # and turn it into a json string. with self._db.lock: try: self._cursor.execute( @@ -328,7 +327,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): raise e return count > 0 - def search_by_name( + def search_by_attr( self, model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 2937eb3a27..7a6a2589c6 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel): setattr(self, key, value) # may raise a validation error -class CheckpointConfig(ModelConfigBase): +class _CheckpointConfig(ModelConfigBase): """Model config for checkpoint-style models.""" format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint config: str = Field(description="path to the checkpoint model config file") -class DiffusersConfig(ModelConfigBase): +class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers -class ControlNetDiffusersConfig(DiffusersConfig): +class ControlNetDiffusersConfig(_DiffusersConfig): """Model config for ControlNet models (diffusers version).""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers -class ControlNetCheckpointConfig(CheckpointConfig): +class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint @@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase): format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] -class MainConfig(ModelConfigBase): +class _MainConfig(ModelConfigBase): """Model config for main models.""" - vae: Optional[str] = Field(None) + vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal ztsnr_training: bool = False -class MainCheckpointConfig(CheckpointConfig, MainConfig): +class MainCheckpointConfig(_CheckpointConfig, _MainConfig): """Model config for main checkpoint models.""" # Note that we do not need prediction_type or upcast_attention here # because they are provided in the checkpoint's own config file. -class MainDiffusersConfig(DiffusersConfig, MainConfig): +class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False -class ONNXSD1Config(MainConfig): +class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" format: Literal[ModelFormat.Onnx, ModelFormat.Olive] @@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig): upcast_attention: bool = False -class ONNXSD2Config(MainConfig): +class ONNXSD2Config(_MainConfig): """Model config for ONNX format models based on sd-2.""" format: Literal[ModelFormat.Onnx, ModelFormat.Olive] diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 8668b8dd53..49e63b93b9 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -16,7 +16,7 @@ from invokeai.app.services.model_records import ( from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.backend.model_manager.config import ( BaseModelType, - DiffusersConfig, + MainDiffusersConfig, ModelType, TextualInversionConfig, VaeDiffusersConfig, @@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase): new_config = store.get_model("key1") assert new_config.name == "new name" +def test_rename(store: ModelRecordServiceBase): + config = example_config() + store.add_model("key1", config) + config = store.get_model("key1") + assert config.name == "old name" + + store.rename_model("key1", "new name") + new_config = store.get_model("key1") + assert new_config.name == "new name" + def test_unknown_key(store: ModelRecordServiceBase): config = example_config() @@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase): def test_filter(store: ModelRecordServiceBase): - config1 = DiffusersConfig( + config1 = MainDiffusersConfig( path="/tmp/config1", name="config1", base=BaseModelType("sd-1"), type=ModelType("main"), original_hash="CONFIG1HASH", ) - config2 = DiffusersConfig( + config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", base=BaseModelType("sd-1"), @@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase): ) for c in config1, config2, config3: store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) - matches = store.search_by_name(model_type=ModelType("main")) + matches = store.search_by_attr(model_type=ModelType("main")) assert len(matches) == 2 assert matches[0].name in {"config1", "config2"} - matches = store.search_by_name(model_type=ModelType("vae")) + matches = store.search_by_attr(model_type=ModelType("vae")) assert len(matches) == 1 assert matches[0].name == "config3" assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back - matches = store.search_by_name(model_type=BaseModelType("sd-2")) + matches = store.search_by_attr(model_type=BaseModelType("sd-2")) matches = store.search_by_hash("CONFIG1HASH") assert len(matches) == 1 @@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase): def test_filter_2(store: ModelRecordServiceBase): - config1 = DiffusersConfig( + config1 = MainDiffusersConfig( path="/tmp/config1", name="config1", base=BaseModelType("sd-1"), type=ModelType("main"), original_hash="CONFIG1HASH", ) - config2 = DiffusersConfig( + config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", base=BaseModelType("sd-1"), type=ModelType("main"), original_hash="CONFIG2HASH", ) - config3 = DiffusersConfig( + config3 = MainDiffusersConfig( path="/tmp/config3", name="dup_name1", base=BaseModelType("sd-2"), type=ModelType("main"), original_hash="CONFIG3HASH", ) - config4 = DiffusersConfig( + config4 = MainDiffusersConfig( path="/tmp/config4", name="dup_name1", base=BaseModelType("sd-2"), @@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase): for c in config1, config2, config3, config4, config5: store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) - matches = store.search_by_name( + matches = store.search_by_attr( model_type=ModelType("main"), model_name="dup_name1", ) assert len(matches) == 2 - matches = store.search_by_name( + matches = store.search_by_attr( base_model=BaseModelType("sd-1"), model_type=ModelType("main"), ) assert len(matches) == 2 - matches = store.search_by_name( + matches = store.search_by_attr( base_model=BaseModelType("sd-1"), model_type=ModelType("vae"), model_name="dup_name1",