feat(mm): revise update_model to use ModelRecordChanges

This commit is contained in:
psychedelicious 2024-03-05 12:04:27 +11:00
parent 37b969d339
commit 5551cf8ac4
5 changed files with 50 additions and 59 deletions

View File

@ -19,6 +19,7 @@ from invokeai.app.services.model_records import (
ModelSummary, ModelSummary,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
@ -263,15 +264,13 @@ async def scan_for_models(
) )
async def update_model_record( async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[ changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update a model's config."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
model_response: AnyModelConfig = record_store.update_model(key, config=info) model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -563,7 +562,8 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict # temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name original_name = model_config.name
model_config.name = f"{original_name}.DELETE" model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config) changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers # install the diffusers
try: try:

View File

@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import ModelDefaultSettings from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -60,15 +60,24 @@ class ModelSummary(BaseModel):
class ModelRecordChanges(BaseModelExcludeNull, extra="allow"): class ModelRecordChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata. """A set of changes to apply to a model."""
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field( default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model" description="Default settings for this model", default=None
) )
"""The user-configured default settings for this model"""
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
class ModelRecordServiceBase(ABC): class ModelRecordServiceBase(ABC):
@ -99,13 +108,12 @@ class ModelRecordServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
""" """
Update the model, returning the updated version. Update the model, returning the updated version.
:param key: Unique key for the model to be updated :param key: Unique key for the model to be updated.
:param config: Model configuration record. Either a dict with the :param changes: A set of changes to apply to this model. Changes are validated before being written.
required fields, or a ModelConfigBase instance.
""" """
pass pass
@ -194,21 +202,3 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
) )
return model_configs[0] return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -43,7 +43,7 @@ import json
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -57,6 +57,7 @@ from invokeai.backend.model_manager.config import (
from ..shared.sqlite.sqlite_database import SqliteDatabase from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import ( from .model_records_base import (
DuplicateModelException, DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy, ModelRecordOrderBy,
ModelRecordServiceBase, ModelRecordServiceBase,
ModelSummary, ModelSummary,
@ -151,16 +152,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
""" record = self.get_model(key)
Update the model, returning the updated version.
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(

View File

@ -162,7 +162,7 @@ class ModelConfigBase(BaseModel):
description="Default settings for this model", default=None description="Default settings for this model", default=None
) )
model_config = ConfigDict(use_enum_values=False, validate_assignment=True) model_config = ConfigDict(validate_assignment=True)
class CheckpointConfigBase(ModelConfigBase): class CheckpointConfigBase(ModelConfigBase):

View File

@ -6,6 +6,7 @@ from hashlib import sha256
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
@ -14,6 +15,7 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL, ModelRecordServiceSQL,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
MainDiffusersConfig, MainDiffusersConfig,
@ -73,34 +75,33 @@ def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
store.add_model(config2) store.add_model(config2)
def test_update(store: ModelRecordServiceBase): def test_model_records_updates_model(store: ModelRecordServiceBase):
config = example_config("key1") config = example_config("key1")
store.add_model(config) store.add_model(config)
config = store.get_model("key1") config = store.get_model("key1")
assert config.name == "old name" assert config.name == "old name"
new_name = "new name"
config.name = "new name" changes = ModelRecordChanges(name=new_name)
store.update_model(config.key, config) store.update_model(config.key, changes)
new_config = store.get_model("key1") new_config = store.get_model("key1")
assert new_config.name == "new name" assert new_config.name == new_name
def test_rename(store: ModelRecordServiceBase): def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
config = example_config("key1") config = example_config("key1")
store.add_model(config) store.add_model(config)
config = store.get_model("key1") config = store.get_model("key1")
assert config.name == "old name" # upcast_attention is an invalid field for TIs
changes = ModelRecordChanges(upcast_attention=True)
store.rename_model("key1", "new name") with pytest.raises(ValidationError):
new_config = store.get_model("key1") store.update_model(config.key, changes)
assert new_config.name == "new name"
def test_unknown_key(store: ModelRecordServiceBase): def test_unknown_key(store: ModelRecordServiceBase):
config = example_config("key1") config = example_config("key1")
store.add_model(config) store.add_model(config)
with pytest.raises(UnknownModelException): with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config) store.update_model("unknown_key", ModelRecordChanges())
def test_delete(store: ModelRecordServiceBase): def test_delete(store: ModelRecordServiceBase):