feat(mm): revise update_model to use ModelRecordChanges

This commit is contained in:
psychedelicious 2024-03-05 12:04:27 +11:00
parent 37b969d339
commit 5551cf8ac4
5 changed files with 50 additions and 59 deletions

View File

@ -19,6 +19,7 @@ from invokeai.app.services.model_records import (
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
@ -263,15 +264,13 @@ async def scan_for_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -563,7 +562,8 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
store.update_model(key, config=model_config)
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:

View File

@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records.
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import ModelDefaultSettings
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
class DuplicateModelException(Exception):
@ -60,15 +60,24 @@ class ModelSummary(BaseModel):
class ModelRecordChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
"""A set of changes to apply to a model."""
# Changes applicable to all models
name: Optional[str] = Field(description="Name of the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
description="Default settings for this model", default=None
)
"""The user-configured default settings for this model"""
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
class ModelRecordServiceBase(ABC):
@ -99,13 +108,12 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
"""
pass
@ -194,21 +202,3 @@ class ModelRecordServiceBase(ABC):
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -43,7 +43,7 @@ import json
import sqlite3
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional, Union
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
@ -57,6 +57,7 @@ from invokeai.backend.model_manager.config import (
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
@ -151,16 +152,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(

View File

@ -162,7 +162,7 @@ class ModelConfigBase(BaseModel):
description="Default settings for this model", default=None
)
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
model_config = ConfigDict(validate_assignment=True)
class CheckpointConfigBase(ModelConfigBase):

View File

@ -6,6 +6,7 @@ from hashlib import sha256
from typing import Any, Optional
import pytest
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
@ -14,6 +15,7 @@ from invokeai.app.services.model_records import (
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
BaseModelType,
MainDiffusersConfig,
@ -73,34 +75,33 @@ def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
store.add_model(config2)
def test_update(store: ModelRecordServiceBase):
def test_model_records_updates_model(store: ModelRecordServiceBase):
config = example_config("key1")
store.add_model(config)
config = store.get_model("key1")
assert config.name == "old name"
config.name = "new name"
store.update_model(config.key, config)
new_name = "new name"
changes = ModelRecordChanges(name=new_name)
store.update_model(config.key, changes)
new_config = store.get_model("key1")
assert new_config.name == "new name"
assert new_config.name == new_name
def test_rename(store: ModelRecordServiceBase):
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
config = example_config("key1")
store.add_model(config)
config = store.get_model("key1")
assert config.name == "old name"
store.rename_model("key1", "new name")
new_config = store.get_model("key1")
assert new_config.name == "new name"
# upcast_attention is an invalid field for TIs
changes = ModelRecordChanges(upcast_attention=True)
with pytest.raises(ValidationError):
store.update_model(config.key, changes)
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config("key1")
store.add_model(config)
with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config)
store.update_model("unknown_key", ModelRecordChanges())
def test_delete(store: ModelRecordServiceBase):