mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
multiple small fixes suggested in reviews from psychedelicious and ryan
This commit is contained in:
parent
fdaa661245
commit
0544917161
@ -39,7 +39,7 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -110,7 +110,7 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(
|
||||
def search_by_attr(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
@ -130,16 +130,16 @@ class ModelRecordServiceBase(ABC):
|
||||
|
||||
def all_models(self) -> List[AnyModelConfig]:
|
||||
"""Return all the model configs in the database."""
|
||||
return self.search_by_name()
|
||||
return self.search_by_attr()
|
||||
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
|
||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
|
||||
"""
|
||||
Return information about a single model using its name, base type and model type.
|
||||
|
||||
If there are more than one model that match, raises a DuplicateModelException.
|
||||
If no model matches, raises an UnknownModelException
|
||||
"""
|
||||
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
if len(model_configs) > 1:
|
||||
raise DuplicateModelException(
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
@ -154,7 +154,7 @@ class ModelRecordServiceBase(ABC):
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> ModelConfigBase:
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
@ -164,4 +164,6 @@ class ModelRecordServiceBase(ABC):
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
return self.update_model(key, {"name": new_name})
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
@ -36,7 +36,7 @@ Typical usage:
|
||||
# searching
|
||||
configs = store.search_by_path(path='/tmp/pokemon.bin')
|
||||
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
|
||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
||||
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
|
||||
@ -77,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._db.conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
with self._db.lock:
|
||||
@ -157,7 +156,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -168,7 +167,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
@ -252,7 +251,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@ -261,7 +260,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
@ -328,7 +327,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
raise e
|
||||
return count > 0
|
||||
|
||||
def search_by_name(
|
||||
def search_by_attr(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
|
@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel):
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfig):
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfig):
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase):
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
|
||||
|
||||
class MainConfig(ModelConfigBase):
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
vae: Optional[str] = Field(None)
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
# Note that we do not need prediction_type or upcast_attention here
|
||||
# because they are provided in the checkpoint's own config file.
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(MainConfig):
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig):
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD2Config(MainConfig):
|
||||
class ONNXSD2Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-2."""
|
||||
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
|
@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
DiffusersConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
def test_rename(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config = store.get_model("key1")
|
||||
assert config.name == "old name"
|
||||
|
||||
store.rename_model("key1", "new name")
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
|
||||
def test_unknown_key(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
|
||||
|
||||
|
||||
def test_filter(store: ModelRecordServiceBase):
|
||||
config1 = DiffusersConfig(
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = DiffusersConfig(
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_name(model_type=ModelType("main"))
|
||||
matches = store.search_by_attr(model_type=ModelType("main"))
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_name(model_type=ModelType("vae"))
|
||||
matches = store.search_by_attr(model_type=ModelType("vae"))
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_name(model_type=BaseModelType("sd-2"))
|
||||
matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
|
||||
|
||||
matches = store.search_by_hash("CONFIG1HASH")
|
||||
assert len(matches) == 1
|
||||
@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
|
||||
|
||||
def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = DiffusersConfig(
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = DiffusersConfig(
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = DiffusersConfig(
|
||||
config3 = MainDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config4 = DiffusersConfig(
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
model_type=ModelType("main"),
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("main"),
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("vae"),
|
||||
model_name="dup_name1",
|
||||
|
Loading…
Reference in New Issue
Block a user