diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 5e12a9923f..2de31dc2bd 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -19,6 +19,7 @@ from invokeai.app.services.model_records import ( ModelSummary, UnknownModelException, ) +from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -263,15 +264,13 @@ async def scan_for_models( ) async def update_model_record( key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[ - AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) - ], + changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)], ) -> AnyModelConfig: - """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" + """Update a model's config.""" logger = ApiDependencies.invoker.services.logger record_store = ApiDependencies.invoker.services.model_manager.store try: - model_response: AnyModelConfig = record_store.update_model(key, config=info) + model_response: AnyModelConfig = record_store.update_model(key, changes=changes) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) @@ -563,7 +562,8 @@ async def convert_model( # temporarily rename the original safetensors file so that there is no naming conflict original_name = model_config.name model_config.name = f"{original_name}.DELETE" - store.update_model(key, config=model_config) + changes = ModelRecordChanges(name=model_config.name) + store.update_model(key, changes=changes) # install the diffusers try: diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 1650f52a65..8d0110b1e1 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records. from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from typing import List, Optional, Set, Union from pydantic import BaseModel, Field @@ -18,7 +18,7 @@ from invokeai.backend.model_manager import ( ModelFormat, ModelType, ) -from invokeai.backend.model_manager.config import ModelDefaultSettings +from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType class DuplicateModelException(Exception): @@ -60,15 +60,24 @@ class ModelSummary(BaseModel): class ModelRecordChanges(BaseModelExcludeNull, extra="allow"): - """A set of changes to apply to model metadata. - Only limited changes are valid: - - `default_settings`: the user-configured default settings for this model - """ + """A set of changes to apply to a model.""" + # Changes applicable to all models + name: Optional[str] = Field(description="Name of the model.", default=None) + description: Optional[str] = Field(description="Model description", default=None) + base: Optional[BaseModelType] = Field(description="The base model.", default=None) + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[ModelDefaultSettings] = Field( - default=None, description="The user-configured default settings for this model" + description="Default settings for this model", default=None ) - """The user-configured default settings for this model""" + + # Checkpoint-specific changes + # TODO(MM2): Should we expose these? Feels footgun-y... + variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None) + prediction_type: Optional[SchedulerPredictionType] = Field( + description="The prediction type of the model.", default=None + ) + upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None) class ModelRecordServiceBase(ABC): @@ -99,13 +108,12 @@ class ModelRecordServiceBase(ABC): pass @abstractmethod - def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: + def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: """ Update the model, returning the updated version. - :param key: Unique key for the model to be updated - :param config: Model configuration record. Either a dict with the - required fields, or a ModelConfigBase instance. + :param key: Unique key for the model to be updated. + :param changes: A set of changes to apply to this model. Changes are validated before being written. """ pass @@ -194,21 +202,3 @@ class ModelRecordServiceBase(ABC): f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." ) return model_configs[0] - - def rename_model( - self, - key: str, - new_name: str, - ) -> AnyModelConfig: - """ - Rename the indicated model. Just a special case of update_model(). - - In some implementations, renaming the model may involve changing where - it is stored on the filesystem. So this is broken out. - - :param key: Model key - :param new_name: New name for model - """ - config = self.get_model(key) - config.name = new_name - return self.update_model(key, config) diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index ce4467d833..35c182fb9d 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -43,7 +43,7 @@ import json import sqlite3 from math import ceil from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( @@ -57,6 +57,7 @@ from invokeai.backend.model_manager.config import ( from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_records_base import ( DuplicateModelException, + ModelRecordChanges, ModelRecordOrderBy, ModelRecordServiceBase, ModelSummary, @@ -151,16 +152,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): self._db.conn.rollback() raise e - def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: - """ - Update the model, returning the updated version. + def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: + record = self.get_model(key) + + # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic. + for field_name in changes.model_fields_set: + setattr(record, field_name, getattr(changes, field_name)) + + json_serialized = record.model_dump_json() - :param key: Unique key for the model to be updated - :param config: Model configuration record. Either a dict with the - required fields, or a ModelConfigBase instance. - """ - record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect - json_serialized = record.model_dump_json() # and turn it into a json string. with self._db.lock: try: self._cursor.execute( diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 676edd3758..78149af190 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -162,7 +162,7 @@ class ModelConfigBase(BaseModel): description="Default settings for this model", default=None ) - model_config = ConfigDict(use_enum_values=False, validate_assignment=True) + model_config = ConfigDict(validate_assignment=True) class CheckpointConfigBase(ModelConfigBase): diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 1552d4c07d..a0151de4cb 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -6,6 +6,7 @@ from hashlib import sha256 from typing import Any, Optional import pytest +from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ( @@ -14,6 +15,7 @@ from invokeai.app.services.model_records import ( ModelRecordServiceSQL, UnknownModelException, ) +from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( BaseModelType, MainDiffusersConfig, @@ -73,34 +75,33 @@ def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase): store.add_model(config2) -def test_update(store: ModelRecordServiceBase): +def test_model_records_updates_model(store: ModelRecordServiceBase): config = example_config("key1") store.add_model(config) config = store.get_model("key1") assert config.name == "old name" - - config.name = "new name" - store.update_model(config.key, config) + new_name = "new name" + changes = ModelRecordChanges(name=new_name) + store.update_model(config.key, changes) new_config = store.get_model("key1") - assert new_config.name == "new name" + assert new_config.name == new_name -def test_rename(store: ModelRecordServiceBase): +def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase): config = example_config("key1") store.add_model(config) config = store.get_model("key1") - assert config.name == "old name" - - store.rename_model("key1", "new name") - new_config = store.get_model("key1") - assert new_config.name == "new name" + # upcast_attention is an invalid field for TIs + changes = ModelRecordChanges(upcast_attention=True) + with pytest.raises(ValidationError): + store.update_model(config.key, changes) def test_unknown_key(store: ModelRecordServiceBase): config = example_config("key1") store.add_model(config) with pytest.raises(UnknownModelException): - store.update_model("unknown_key", config) + store.update_model("unknown_key", ModelRecordChanges()) def test_delete(store: ModelRecordServiceBase):