implement psychedelicious recommendations as of 13 November

This commit is contained in:
Lincoln Stein 2023-11-13 17:05:01 -05:00
parent 8afe517204
commit efbdb75568
4 changed files with 50 additions and 12 deletions

View File

@ -41,17 +41,13 @@ async def list_model_records(
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
if base_models and len(base_models) > 0:
models_raw = list()
models = list()
if base_models:
for base_model in base_models:
models_raw.extend(
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
)
models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)]
models = ModelsListValidator.validate_python({"models": models_raw})
return models
models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=models)
@model_records_router.get(
"/i/{key}",

View File

@ -107,7 +107,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
@ -200,6 +202,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e

View File

@ -298,7 +298,7 @@ class ModelConfigFactory(object):
@classmethod
def make_config(
cls,
model_data: Union[dict, ModelConfigBase],
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:

View File

@ -170,6 +170,44 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.all_models()
assert len(matches) == 3
def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType("sd-1"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
for c in config1, config2, config3:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
# config4 clashes with config1 and should raise an integrity error
with pytest.raises(DuplicateModelException):
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
@ -196,7 +234,7 @@ def test_filter_2(store: ModelRecordServiceBase):
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sd-2"),
base=BaseModelType("sdxl"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)