mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
multiple small fixes suggested in reviews from psychedelicious and ryan
This commit is contained in:
parent
fdaa661245
commit
0544917161
@ -39,7 +39,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
|
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search_by_name(
|
def search_by_attr(
|
||||||
self,
|
self,
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
@ -130,16 +130,16 @@ class ModelRecordServiceBase(ABC):
|
|||||||
|
|
||||||
def all_models(self) -> List[AnyModelConfig]:
|
def all_models(self) -> List[AnyModelConfig]:
|
||||||
"""Return all the model configs in the database."""
|
"""Return all the model configs in the database."""
|
||||||
return self.search_by_name()
|
return self.search_by_attr()
|
||||||
|
|
||||||
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
|
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Return information about a single model using its name, base type and model type.
|
Return information about a single model using its name, base type and model type.
|
||||||
|
|
||||||
If there are more than one model that match, raises a DuplicateModelException.
|
If there are more than one model that match, raises a DuplicateModelException.
|
||||||
If no model matches, raises an UnknownModelException
|
If no model matches, raises an UnknownModelException
|
||||||
"""
|
"""
|
||||||
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
|
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||||
if len(model_configs) > 1:
|
if len(model_configs) > 1:
|
||||||
raise DuplicateModelException(
|
raise DuplicateModelException(
|
||||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||||
@ -154,7 +154,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
new_name: str,
|
new_name: str,
|
||||||
) -> ModelConfigBase:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Rename the indicated model. Just a special case of update_model().
|
Rename the indicated model. Just a special case of update_model().
|
||||||
|
|
||||||
@ -164,4 +164,6 @@ class ModelRecordServiceBase(ABC):
|
|||||||
:param key: Model key
|
:param key: Model key
|
||||||
:param new_name: New name for model
|
:param new_name: New name for model
|
||||||
"""
|
"""
|
||||||
return self.update_model(key, {"name": new_name})
|
config = self.get_model(key)
|
||||||
|
config.name = new_name
|
||||||
|
return self.update_model(key, config)
|
||||||
|
@ -36,7 +36,7 @@ Typical usage:
|
|||||||
# searching
|
# searching
|
||||||
configs = store.search_by_path(path='/tmp/pokemon.bin')
|
configs = store.search_by_path(path='/tmp/pokemon.bin')
|
||||||
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
|
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
|
||||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +77,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._db = db
|
self._db = db
|
||||||
self._db.conn.row_factory = sqlite3.Row
|
|
||||||
self._cursor = self._db.conn.cursor()
|
self._cursor = self._db.conn.cursor()
|
||||||
|
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
@ -157,7 +156,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
("version", CONFIG_FILE_VERSION),
|
("version", CONFIG_FILE_VERSION),
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model to the database.
|
Add a model to the database.
|
||||||
|
|
||||||
@ -168,7 +167,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||||
"""
|
"""
|
||||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -252,7 +251,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Update the model, returning the updated version.
|
Update the model, returning the updated version.
|
||||||
|
|
||||||
@ -261,7 +260,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
required fields, or a ModelConfigBase instance.
|
required fields, or a ModelConfigBase instance.
|
||||||
"""
|
"""
|
||||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -328,7 +327,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
raise e
|
raise e
|
||||||
return count > 0
|
return count > 0
|
||||||
|
|
||||||
def search_by_name(
|
def search_by_attr(
|
||||||
self,
|
self,
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
@ -127,14 +127,14 @@ class ModelConfigBase(BaseModel):
|
|||||||
setattr(self, key, value) # may raise a validation error
|
setattr(self, key, value) # may raise a validation error
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class _CheckpointConfig(ModelConfigBase):
|
||||||
"""Model config for checkpoint-style models."""
|
"""Model config for checkpoint-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
config: str = Field(description="path to the checkpoint model config file")
|
config: str = Field(description="path to the checkpoint model config file")
|
||||||
|
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class _DiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for diffusers-style models."""
|
"""Model config for diffusers-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
@ -158,13 +158,13 @@ class VaeDiffusersConfig(ModelConfigBase):
|
|||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
|
||||||
class ControlNetDiffusersConfig(DiffusersConfig):
|
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointConfig(CheckpointConfig):
|
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
@ -176,29 +176,29 @@ class TextualInversionConfig(ModelConfigBase):
|
|||||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||||
|
|
||||||
|
|
||||||
class MainConfig(ModelConfigBase):
|
class _MainConfig(ModelConfigBase):
|
||||||
"""Model config for main models."""
|
"""Model config for main models."""
|
||||||
|
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(default=None)
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
ztsnr_training: bool = False
|
ztsnr_training: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MainCheckpointConfig(CheckpointConfig, MainConfig):
|
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||||
"""Model config for main checkpoint models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
# Note that we do not need prediction_type or upcast_attention here
|
# Note that we do not need prediction_type or upcast_attention here
|
||||||
# because they are provided in the checkpoint's own config file.
|
# because they are provided in the checkpoint's own config file.
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(DiffusersConfig, MainConfig):
|
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1Config(MainConfig):
|
class ONNXSD1Config(_MainConfig):
|
||||||
"""Model config for ONNX format models based on sd-1."""
|
"""Model config for ONNX format models based on sd-1."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||||
@ -206,7 +206,7 @@ class ONNXSD1Config(MainConfig):
|
|||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD2Config(MainConfig):
|
class ONNXSD2Config(_MainConfig):
|
||||||
"""Model config for ONNX format models based on sd-2."""
|
"""Model config for ONNX format models based on sd-2."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||||
|
@ -16,7 +16,7 @@ from invokeai.app.services.model_records import (
|
|||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
DiffusersConfig,
|
MainDiffusersConfig,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionConfig,
|
TextualInversionConfig,
|
||||||
VaeDiffusersConfig,
|
VaeDiffusersConfig,
|
||||||
@ -83,6 +83,16 @@ def test_update(store: ModelRecordServiceBase):
|
|||||||
new_config = store.get_model("key1")
|
new_config = store.get_model("key1")
|
||||||
assert new_config.name == "new name"
|
assert new_config.name == "new name"
|
||||||
|
|
||||||
|
def test_rename(store: ModelRecordServiceBase):
|
||||||
|
config = example_config()
|
||||||
|
store.add_model("key1", config)
|
||||||
|
config = store.get_model("key1")
|
||||||
|
assert config.name == "old name"
|
||||||
|
|
||||||
|
store.rename_model("key1", "new name")
|
||||||
|
new_config = store.get_model("key1")
|
||||||
|
assert new_config.name == "new name"
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_key(store: ModelRecordServiceBase):
|
def test_unknown_key(store: ModelRecordServiceBase):
|
||||||
config = example_config()
|
config = example_config()
|
||||||
@ -108,14 +118,14 @@ def test_exists(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
|
|
||||||
def test_filter(store: ModelRecordServiceBase):
|
def test_filter(store: ModelRecordServiceBase):
|
||||||
config1 = DiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = DiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
@ -131,17 +141,17 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
)
|
)
|
||||||
for c in config1, config2, config3:
|
for c in config1, config2, config3:
|
||||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||||
matches = store.search_by_name(model_type=ModelType("main"))
|
matches = store.search_by_attr(model_type=ModelType("main"))
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
assert matches[0].name in {"config1", "config2"}
|
assert matches[0].name in {"config1", "config2"}
|
||||||
|
|
||||||
matches = store.search_by_name(model_type=ModelType("vae"))
|
matches = store.search_by_attr(model_type=ModelType("vae"))
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].name == "config3"
|
assert matches[0].name == "config3"
|
||||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||||
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
|
||||||
|
|
||||||
matches = store.search_by_name(model_type=BaseModelType("sd-2"))
|
matches = store.search_by_attr(model_type=BaseModelType("sd-2"))
|
||||||
|
|
||||||
matches = store.search_by_hash("CONFIG1HASH")
|
matches = store.search_by_hash("CONFIG1HASH")
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
@ -152,28 +162,28 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
|
|
||||||
def test_filter_2(store: ModelRecordServiceBase):
|
def test_filter_2(store: ModelRecordServiceBase):
|
||||||
config1 = DiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = DiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType("sd-1"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG2HASH",
|
original_hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = DiffusersConfig(
|
config3 = MainDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
config4 = DiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
@ -190,19 +200,19 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
for c in config1, config2, config3, config4, config5:
|
for c in config1, config2, config3, config4, config5:
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||||
|
|
||||||
matches = store.search_by_name(
|
matches = store.search_by_attr(
|
||||||
model_type=ModelType("main"),
|
model_type=ModelType("main"),
|
||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
)
|
)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
matches = store.search_by_name(
|
matches = store.search_by_attr(
|
||||||
base_model=BaseModelType("sd-1"),
|
base_model=BaseModelType("sd-1"),
|
||||||
model_type=ModelType("main"),
|
model_type=ModelType("main"),
|
||||||
)
|
)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
matches = store.search_by_name(
|
matches = store.search_by_attr(
|
||||||
base_model=BaseModelType("sd-1"),
|
base_model=BaseModelType("sd-1"),
|
||||||
model_type=ModelType("vae"),
|
model_type=ModelType("vae"),
|
||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user