multiple small fixes suggested in reviews from psychedelicious and ryan

This commit is contained in:
Lincoln Stein 2023-11-10 18:25:37 -05:00
parent fdaa661245
commit 0544917161
4 changed files with 48 additions and 37 deletions

View File

@ -39,7 +39,7 @@ class ModelRecordServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase: def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -110,7 +110,7 @@ class ModelRecordServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def search_by_name( def search_by_attr(
self, self,
model_name: Optional[str] = None, model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
@ -130,16 +130,16 @@ class ModelRecordServiceBase(ABC):
def all_models(self) -> List[AnyModelConfig]: def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database.""" """Return all the model configs in the database."""
return self.search_by_name() return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
""" """
Return information about a single model using its name, base type and model type. Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException. If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException If no model matches, raises an UnknownModelException
""" """
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type) model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1: if len(model_configs) > 1:
raise DuplicateModelException( raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'." f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
@ -154,7 +154,7 @@ class ModelRecordServiceBase(ABC):
self, self,
key: str, key: str,
new_name: str, new_name: str,
) -> ModelConfigBase: ) -> AnyModelConfig:
""" """
Rename the indicated model. Just a special case of update_model(). Rename the indicated model. Just a special case of update_model().
@ -164,4 +164,6 @@ class ModelRecordServiceBase(ABC):
:param key: Model key :param key: Model key
:param new_name: New name for model :param new_name: New name for model
""" """
return self.update_model(key, {"name": new_name}) config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -36,7 +36,7 @@ Typical usage:
# searching # searching
configs = store.search_by_path(path='/tmp/pokemon.bin') configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01') configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_name(base_model='sd-2', model_type='main') configs = store.search_by_attr(base_model='sd-2', model_type='main')
""" """
@ -77,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""" """
super().__init__() super().__init__()
self._db = db self._db = db
self._db.conn.row_factory = sqlite3.Row
self._cursor = self._db.conn.cursor() self._cursor = self._db.conn.cursor()
with self._db.lock: with self._db.lock:
@ -157,7 +156,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
("version", CONFIG_FILE_VERSION), ("version", CONFIG_FILE_VERSION),
) )
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
""" """
Add a model to the database. Add a model to the database.
@ -168,7 +167,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions. Can raise DuplicateModelException and InvalidModelConfigException exceptions.
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
@ -252,7 +251,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
""" """
Update the model, returning the updated version. Update the model, returning the updated version.
@ -261,7 +260,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
required fields, or a ModelConfigBase instance. required fields, or a ModelConfigBase instance.
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
@ -328,7 +327,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
raise e raise e
return count > 0 return count > 0
def search_by_name( def search_by_attr(
self, self,
model_name: Optional[str] = None, model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,

View File

@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel):
setattr(self, key, value) # may raise a validation error setattr(self, key, value) # may raise a validation error
class CheckpointConfig(ModelConfigBase): class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models.""" """Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file") config: str = Field(description="path to the checkpoint model config file")
class DiffusersConfig(ModelConfigBase): class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models.""" """Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfig): class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfig): class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class MainConfig(ModelConfigBase): class _MainConfig(ModelConfigBase):
"""Model config for main models.""" """Model config for main models."""
vae: Optional[str] = Field(None) vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig): class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models.""" """Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here # Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file. # because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig): class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False
class ONNXSD1Config(MainConfig): class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1.""" """Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig):
upcast_attention: bool = False upcast_attention: bool = False
class ONNXSD2Config(MainConfig): class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2.""" """Model config for ONNX format models based on sd-2."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]

View File

@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
DiffusersConfig, MainDiffusersConfig,
ModelType, ModelType,
TextualInversionConfig, TextualInversionConfig,
VaeDiffusersConfig, VaeDiffusersConfig,
@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1") new_config = store.get_model("key1")
assert new_config.name == "new name" assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = store.get_model("key1")
assert config.name == "old name"
store.rename_model("key1", "new name")
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_unknown_key(store: ModelRecordServiceBase): def test_unknown_key(store: ModelRecordServiceBase):
config = example_config() config = example_config()
@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
def test_filter(store: ModelRecordServiceBase): def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig( config1 = MainDiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = DiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
) )
for c in config1, config2, config3: for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(model_type=ModelType("main")) matches = store.search_by_attr(model_type=ModelType("main"))
assert len(matches) == 2 assert len(matches) == 2
assert matches[0].name in {"config1", "config2"} assert matches[0].name in {"config1", "config2"}
matches = store.search_by_name(model_type=ModelType("vae")) matches = store.search_by_attr(model_type=ModelType("vae"))
assert len(matches) == 1 assert len(matches) == 1
assert matches[0].name == "config3" assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_name(model_type=BaseModelType("sd-2")) matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
matches = store.search_by_hash("CONFIG1HASH") matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1 assert len(matches) == 1
@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
def test_filter_2(store: ModelRecordServiceBase): def test_filter_2(store: ModelRecordServiceBase):
config1 = DiffusersConfig( config1 = MainDiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = DiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base=BaseModelType("sd-1"), base=BaseModelType("sd-1"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG2HASH", original_hash="CONFIG2HASH",
) )
config3 = DiffusersConfig( config3 = MainDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
name="dup_name1", name="dup_name1",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )
config4 = DiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
name="dup_name1", name="dup_name1",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
for c in config1, config2, config3, config4, config5: for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name( matches = store.search_by_attr(
model_type=ModelType("main"), model_type=ModelType("main"),
model_name="dup_name1", model_name="dup_name1",
) )
assert len(matches) == 2 assert len(matches) == 2
matches = store.search_by_name( matches = store.search_by_attr(
base_model=BaseModelType("sd-1"), base_model=BaseModelType("sd-1"),
model_type=ModelType("main"), model_type=ModelType("main"),
) )
assert len(matches) == 2 assert len(matches) == 2
matches = store.search_by_name( matches = store.search_by_attr(
base_model=BaseModelType("sd-1"), base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"), model_type=ModelType("vae"),
model_name="dup_name1", model_name="dup_name1",