diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 8df988dec1..d6eb11d671 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, from ..dependencies import ApiDependencies -model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"]) +model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"]) ModelConfigValidator = TypeAdapter(AnyModelConfig) @@ -34,7 +34,7 @@ ModelsListValidator = TypeAdapter(ModelsList) @model_records_router.get( "/", - operation_id="list_model_configs", + operation_id="list_model_recordss", responses={200: {"model": ModelsList}}, ) async def list_model_records( diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 897047a518..2bf3e3acf4 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -27,7 +27,7 @@ Typical usage: # fetching config new_config = store.get_model('key1') - print(new_config.name, new_config.base_model) + print(new_config.name, new_config.base) assert new_config.key == 'key1' # deleting @@ -100,11 +100,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """--sql CREATE TABLE IF NOT EXISTS model_config ( id TEXT NOT NULL PRIMARY KEY, - -- These 4 fields are enums in python, unrestricted string here - base_model TEXT NOT NULL, - model_type TEXT NOT NULL, - model_name TEXT NOT NULL, - model_path TEXT NOT NULL, + -- The next 3 fields are enums in python, unrestricted string here + base TEXT NOT NULL, + type TEXT NOT NULL, + name TEXT NOT NULL, + path TEXT NOT NULL, original_hash TEXT, -- could be null -- Serialized JSON representation of the whole config object, -- which will contain additional fields from subclasses @@ -139,6 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """ ) + # Add indexes for searchable fields + for stmt in [ + "CREATE INDEX IF NOT EXISTS base_index ON model_config(base);", + "CREATE INDEX IF NOT EXISTS type_index ON model_config(type);", + "CREATE INDEX IF NOT EXISTS name_index ON model_config(name);", + "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);", + ]: + self._cursor.execute(stmt) + # Add our version to the metadata table self._cursor.execute( """--sql @@ -169,10 +178,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """--sql INSERT INTO model_config ( id, - base_model, - model_type, - model_name, - model_path, + base, + type, + name, + path, original_hash, config ) @@ -180,7 +189,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """, ( key, - record.base_model, + record.base, record.type, record.name, record.path, @@ -193,7 +202,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): except sqlite3.IntegrityError as e: self._conn.rollback() if "UNIQUE constraint failed" in str(e): - raise DuplicateModelException(f"A model with key '{key}' is already installed") from e + if "model_config.path" in str(e): + msg = f"A model with path '{record.path}' is already installed" + else: + msg = f"A model with key '{key}' is already installed" + raise DuplicateModelException(msg) from e else: raise e except sqlite3.Error as e: @@ -257,14 +270,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._cursor.execute( """--sql UPDATE model_config - SET base_model=?, - model_type=?, - model_name=?, - model_path=?, + SET base=?, + type=?, + name=?, + path=?, config=? WHERE id=?; """, - (record.base_model, record.type, record.name, record.path, json_serialized, key), + (record.base, record.type, record.name, record.path, json_serialized, key), ) if self._cursor.rowcount == 0: raise UnknownModelException("model not found") @@ -338,13 +351,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): where_clause = [] bindings = [] if model_name: - where_clause.append("model_name=?") + where_clause.append("name=?") bindings.append(model_name) if base_model: - where_clause.append("base_model=?") + where_clause.append("base=?") bindings.append(base_model) if model_type: - where_clause.append("model_type=?") + where_clause.append("type=?") bindings.append(model_type) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" with self._lock: diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index ab89c093a1..2937eb3a27 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -7,8 +7,8 @@ Typical usage: from invokeai.backend.model_manager import ModelConfigFactory raw = dict(path='models/sd-1/main/foo.ckpt', name='foo', - base_model='sd-1', - model_type='main', + base='sd-1', + type='main', config='configs/stable-diffusion/v1-inference.yaml', variant='normal', format='checkpoint' @@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel): path: str name: str - base_model: BaseModelType + base: BaseModelType type: ModelType format: ModelFormat key: str = Field(description="unique key for model", default="") @@ -181,20 +181,29 @@ class MainConfig(ModelConfigBase): vae: Optional[str] = Field(None) variant: ModelVariantType = ModelVariantType.Normal + ztsnr_training: bool = False class MainCheckpointConfig(CheckpointConfig, MainConfig): """Model config for main checkpoint models.""" + # Note that we do not need prediction_type or upcast_attention here + # because they are provided in the checkpoint's own config file. + class MainDiffusersConfig(DiffusersConfig, MainConfig): """Model config for main diffusers models.""" + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False + class ONNXSD1Config(MainConfig): """Model config for ONNX format models based on sd-1.""" format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False class ONNXSD2Config(MainConfig): @@ -202,8 +211,8 @@ class ONNXSD2Config(MainConfig): format: Literal[ModelFormat.Onnx, ModelFormat.Olive] # No yaml config file for ONNX, so these are part of config - prediction_type: SchedulerPredictionType - upcast_attention: bool + prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction + upcast_attention: bool = True class IPAdapterConfig(ModelConfigBase): @@ -305,7 +314,7 @@ class ModelConfigFactory(object): try: format = model_data.get("format") type = model_data.get("type") - model_base = model_data.get("base_model") + model_base = model_data.get("base") class_to_return = dest_class or cls._class_map[format][type] if isinstance(class_to_return, dict): # additional level allowed class_to_return = class_to_return[model_base] diff --git a/invokeai/backend/model_manager/migrate_to_db.py b/invokeai/backend/model_manager/migrate_to_db.py index 76216805af..6cb7f478d1 100644 --- a/invokeai/backend/model_manager/migrate_to_db.py +++ b/invokeai/backend/model_manager/migrate_to_db.py @@ -50,7 +50,7 @@ class Migrate: hash = FastModelHash.hash(self.config.models_path / stanza.path) new_key = sha1(model_key.encode("utf-8")).hexdigest() - stanza["base_model"] = BaseModelType(base_type) + stanza["base"] = BaseModelType(base_type) stanza["type"] = ModelType(model_type) stanza["name"] = model_name stanza["original_hash"] = hash diff --git a/pyproject.toml b/pyproject.toml index f0e3c543c7..b7a851ce75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "fastapi~=0.103.2", "fastapi-events~=0.9.1", "huggingface-hub~=0.16.4", + "imohash", "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "matplotlib", # needed for plotting of Penner easing functions "mediapipe", # needed for "mediapipeface" controlnet model diff --git a/tests/backend/model_manager_2/test_model_storage_sql.py b/tests/backend/model_manager_2/test_model_storage_sql.py index 4fcb375358..308ec7a66f 100644 --- a/tests/backend/model_manager_2/test_model_storage_sql.py +++ b/tests/backend/model_manager_2/test_model_storage_sql.py @@ -8,7 +8,12 @@ from hashlib import sha256 import pytest from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException +from invokeai.app.services.model_records import ( + DuplicateModelException, + ModelRecordServiceBase, + ModelRecordServiceSQL, + UnknownModelException, +) from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.backend.model_manager.config import ( BaseModelType, @@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig: return TextualInversionConfig( path="/tmp/pokemon.bin", name="old name", - base_model="sd-1", - type="embedding", + base=BaseModelType("sd-1"), + type=ModelType("embedding"), format="embedding_file", original_hash="ABC123", ) @@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase): raw = dict( path="/tmp/foo.ckpt", name="model1", - base_model="sd-1", + base=BaseModelType("sd-1"), type="main", config="/tmp/foo.yaml", variant="normal", @@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase): store.add_model("key1", raw) config1 = store.get_model("key1") assert config1 is not None - raw["name"] = "model2" - raw["base_model"] = "sd-2" - raw["format"] = "diffusers" - raw.pop("config") - store.add_model("key2", raw) - config2 = store.get_model("key2") + assert config1.base == BaseModelType("sd-1") assert config1.name == "model1" - assert config2.name == "model2" - assert config1.base_model == "sd-1" - assert config2.base_model == "sd-2" + assert config1.original_hash == "111222333444" + assert config1.current_hash is None + + +def test_dup(store: ModelRecordServiceBase): + config = example_config() + store.add_model("key1", example_config()) + try: + store.add_model("key1", config) + assert False, "Duplicate model key should have been caught" + except DuplicateModelException: + assert True + try: + store.add_model("key2", config) + assert False, "Duplicate model path should have been caught" + except DuplicateModelException: + assert True def test_update(store: ModelRecordServiceBase): @@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase): config1 = DiffusersConfig( path="/tmp/config1", name="config1", - base_model=BaseModelType("sd-1"), + base=BaseModelType("sd-1"), type=ModelType("main"), original_hash="CONFIG1HASH", ) config2 = DiffusersConfig( path="/tmp/config2", name="config2", - base_model=BaseModelType("sd-1"), + base=BaseModelType("sd-1"), type=ModelType("main"), original_hash="CONFIG2HASH", ) config3 = VaeDiffusersConfig( path="/tmp/config3", name="config3", - base_model=BaseModelType("sd-2"), + base=BaseModelType("sd-2"), type=ModelType("vae"), original_hash="CONFIG3HASH", )