multiple small fixes suggested in reviews from psychedelicious and ryan

This commit is contained in:
Lincoln Stein 2023-11-10 18:25:37 -05:00
parent fdaa661245
commit 0544917161
4 changed files with 48 additions and 37 deletions

View File

@ -39,7 +39,7 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -110,7 +110,7 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def search_by_name(
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
@ -130,16 +130,16 @@ class ModelRecordServiceBase(ABC):
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_name()
return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
@ -154,7 +154,7 @@ class ModelRecordServiceBase(ABC):
self,
key: str,
new_name: str,
) -> ModelConfigBase:
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
@ -164,4 +164,6 @@ class ModelRecordServiceBase(ABC):
:param key: Model key
:param new_name: New name for model
"""
return self.update_model(key, {"name": new_name})
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -36,7 +36,7 @@ Typical usage:
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_name(base_model='sd-2', model_type='main')
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
@ -77,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
super().__init__()
self._db = db
self._db.conn.row_factory = sqlite3.Row
self._cursor = self._db.conn.cursor()
with self._db.lock:
@ -157,7 +156,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
"""
Add a model to the database.
@ -168,7 +167,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
@ -252,7 +251,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
"""
Update the model, returning the updated version.
@ -261,7 +260,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
@ -328,7 +327,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
raise e
return count > 0
def search_by_name(
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,

View File

@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel):
setattr(self, key, value) # may raise a validation error
class CheckpointConfig(ModelConfigBase):
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class DiffusersConfig(ModelConfigBase):
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfig):
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfig):
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase):
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class MainConfig(ModelConfigBase):
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(None)
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(CheckpointConfig, MainConfig):
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(DiffusersConfig, MainConfig):
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(MainConfig):
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig):
upcast_attention: bool = False
class ONNXSD2Config(MainConfig):
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]

View File

@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
BaseModelType,
DiffusersConfig,
MainDiffusersConfig,
ModelType,
TextualInversionConfig,
VaeDiffusersConfig,
@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_rename(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config = store.get_model("key1")
assert config.name == "old name"
store.rename_model("key1", "new name")
new_config = store.get_model("key1")
assert new_config.name == "new name"
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config()
@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
def test_filter(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(model_type=ModelType("main"))
matches = store.search_by_attr(model_type=ModelType("main"))
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_name(model_type=ModelType("vae"))
matches = store.search_by_attr(model_type=ModelType("vae"))
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
matches = store.search_by_name(model_type=BaseModelType("sd-2"))
matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1
@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
def test_filter_2(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
)
config3 = DiffusersConfig(
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config4 = DiffusersConfig(
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sd-2"),
@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(
matches = store.search_by_attr(
model_type=ModelType("main"),
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_name(
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
)
assert len(matches) == 2
matches = store.search_by_name(
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
model_name="dup_name1",