mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): revise update_model to use ModelRecordChanges
This commit is contained in:
parent
37b969d339
commit
5551cf8ac4
@ -19,6 +19,7 @@ from invokeai.app.services.model_records import (
|
|||||||
ModelSummary,
|
ModelSummary,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -263,15 +264,13 @@ async def scan_for_models(
|
|||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: Annotated[
|
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
|
||||||
],
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update a model's config."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
try:
|
try:
|
||||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||||
logger.info(f"Updated model: {key}")
|
logger.info(f"Updated model: {key}")
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -563,7 +562,8 @@ async def convert_model(
|
|||||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||||
original_name = model_config.name
|
original_name = model_config.name
|
||||||
model_config.name = f"{original_name}.DELETE"
|
model_config.name = f"{original_name}.DELETE"
|
||||||
store.update_model(key, config=model_config)
|
changes = ModelRecordChanges(name=model_config.name)
|
||||||
|
store.update_model(key, changes=changes)
|
||||||
|
|
||||||
# install the diffusers
|
# install the diffusers
|
||||||
try:
|
try:
|
||||||
|
@ -6,7 +6,7 @@ Abstract base class for storing and retrieving model configuration records.
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import List, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import ModelDefaultSettings
|
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -60,15 +60,24 @@ class ModelSummary(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ModelRecordChanges(BaseModelExcludeNull, extra="allow"):
|
class ModelRecordChanges(BaseModelExcludeNull, extra="allow"):
|
||||||
"""A set of changes to apply to model metadata.
|
"""A set of changes to apply to a model."""
|
||||||
Only limited changes are valid:
|
|
||||||
- `default_settings`: the user-configured default settings for this model
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# Changes applicable to all models
|
||||||
|
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||||
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
|
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||||
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||||
default=None, description="The user-configured default settings for this model"
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
"""The user-configured default settings for this model"""
|
|
||||||
|
# Checkpoint-specific changes
|
||||||
|
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||||
|
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||||
|
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||||
|
description="The prediction type of the model.", default=None
|
||||||
|
)
|
||||||
|
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||||
|
|
||||||
|
|
||||||
class ModelRecordServiceBase(ABC):
|
class ModelRecordServiceBase(ABC):
|
||||||
@ -99,13 +108,12 @@ class ModelRecordServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Update the model, returning the updated version.
|
Update the model, returning the updated version.
|
||||||
|
|
||||||
:param key: Unique key for the model to be updated
|
:param key: Unique key for the model to be updated.
|
||||||
:param config: Model configuration record. Either a dict with the
|
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||||
required fields, or a ModelConfigBase instance.
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -194,21 +202,3 @@ class ModelRecordServiceBase(ABC):
|
|||||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||||
)
|
)
|
||||||
return model_configs[0]
|
return model_configs[0]
|
||||||
|
|
||||||
def rename_model(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
new_name: str,
|
|
||||||
) -> AnyModelConfig:
|
|
||||||
"""
|
|
||||||
Rename the indicated model. Just a special case of update_model().
|
|
||||||
|
|
||||||
In some implementations, renaming the model may involve changing where
|
|
||||||
it is stored on the filesystem. So this is broken out.
|
|
||||||
|
|
||||||
:param key: Model key
|
|
||||||
:param new_name: New name for model
|
|
||||||
"""
|
|
||||||
config = self.get_model(key)
|
|
||||||
config.name = new_name
|
|
||||||
return self.update_model(key, config)
|
|
||||||
|
@ -43,7 +43,7 @@ import json
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -57,6 +57,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from .model_records_base import (
|
from .model_records_base import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
|
ModelRecordChanges,
|
||||||
ModelRecordOrderBy,
|
ModelRecordOrderBy,
|
||||||
ModelRecordServiceBase,
|
ModelRecordServiceBase,
|
||||||
ModelSummary,
|
ModelSummary,
|
||||||
@ -151,16 +152,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||||
"""
|
record = self.get_model(key)
|
||||||
Update the model, returning the updated version.
|
|
||||||
|
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||||
|
for field_name in changes.model_fields_set:
|
||||||
|
setattr(record, field_name, getattr(changes, field_name))
|
||||||
|
|
||||||
|
json_serialized = record.model_dump_json()
|
||||||
|
|
||||||
:param key: Unique key for the model to be updated
|
|
||||||
:param config: Model configuration record. Either a dict with the
|
|
||||||
required fields, or a ModelConfigBase instance.
|
|
||||||
"""
|
|
||||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
|
||||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
|
@ -162,7 +162,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(use_enum_values=False, validate_assignment=True)
|
model_config = ConfigDict(validate_assignment=True)
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfigBase(ModelConfigBase):
|
class CheckpointConfigBase(ModelConfigBase):
|
||||||
|
@ -6,6 +6,7 @@ from hashlib import sha256
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
@ -14,6 +15,7 @@ from invokeai.app.services.model_records import (
|
|||||||
ModelRecordServiceSQL,
|
ModelRecordServiceSQL,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
MainDiffusersConfig,
|
MainDiffusersConfig,
|
||||||
@ -73,34 +75,33 @@ def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
|
|||||||
store.add_model(config2)
|
store.add_model(config2)
|
||||||
|
|
||||||
|
|
||||||
def test_update(store: ModelRecordServiceBase):
|
def test_model_records_updates_model(store: ModelRecordServiceBase):
|
||||||
config = example_config("key1")
|
config = example_config("key1")
|
||||||
store.add_model(config)
|
store.add_model(config)
|
||||||
config = store.get_model("key1")
|
config = store.get_model("key1")
|
||||||
assert config.name == "old name"
|
assert config.name == "old name"
|
||||||
|
new_name = "new name"
|
||||||
config.name = "new name"
|
changes = ModelRecordChanges(name=new_name)
|
||||||
store.update_model(config.key, config)
|
store.update_model(config.key, changes)
|
||||||
new_config = store.get_model("key1")
|
new_config = store.get_model("key1")
|
||||||
assert new_config.name == "new name"
|
assert new_config.name == new_name
|
||||||
|
|
||||||
|
|
||||||
def test_rename(store: ModelRecordServiceBase):
|
def test_model_records_rejects_invalid_changes(store: ModelRecordServiceBase):
|
||||||
config = example_config("key1")
|
config = example_config("key1")
|
||||||
store.add_model(config)
|
store.add_model(config)
|
||||||
config = store.get_model("key1")
|
config = store.get_model("key1")
|
||||||
assert config.name == "old name"
|
# upcast_attention is an invalid field for TIs
|
||||||
|
changes = ModelRecordChanges(upcast_attention=True)
|
||||||
store.rename_model("key1", "new name")
|
with pytest.raises(ValidationError):
|
||||||
new_config = store.get_model("key1")
|
store.update_model(config.key, changes)
|
||||||
assert new_config.name == "new name"
|
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_key(store: ModelRecordServiceBase):
|
def test_unknown_key(store: ModelRecordServiceBase):
|
||||||
config = example_config("key1")
|
config = example_config("key1")
|
||||||
store.add_model(config)
|
store.add_model(config)
|
||||||
with pytest.raises(UnknownModelException):
|
with pytest.raises(UnknownModelException):
|
||||||
store.update_model("unknown_key", config)
|
store.update_model("unknown_key", ModelRecordChanges())
|
||||||
|
|
||||||
|
|
||||||
def test_delete(store: ModelRecordServiceBase):
|
def test_delete(store: ModelRecordServiceBase):
|
||||||
|
Loading…
Reference in New Issue
Block a user