mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement psychedelicious recommendations as of 13 November
This commit is contained in:
parent
8afe517204
commit
efbdb75568
@ -41,17 +41,13 @@ async def list_model_records(
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
models = list()
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
models_raw.extend(
|
||||
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
|
||||
)
|
||||
models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||
else:
|
||||
models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)]
|
||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||
return models
|
||||
|
||||
models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
return ModelsList(models=models)
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
|
@ -107,7 +107,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
@ -200,6 +202,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
|
@ -298,7 +298,7 @@ class ModelConfigFactory(object):
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, ModelConfigBase],
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
) -> AnyModelConfig:
|
||||
|
@ -170,6 +170,44 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
matches = store.all_models()
|
||||
assert len(matches) == 3
|
||||
|
||||
def test_unique(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("vae"),
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
# config1, config2 and config3 are compatible because they have unique combos
|
||||
# of name, type and base
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
||||
# config4 clashes with config1 and should raise an integrity error
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
|
||||
|
||||
|
||||
def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
@ -196,7 +234,7 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
base=BaseModelType("sdxl"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user