implement psychedelicious recommendations as of 13 November

This commit is contained in:
Lincoln Stein 2023-11-13 17:05:01 -05:00
parent 8afe517204
commit efbdb75568
4 changed files with 50 additions and 12 deletions

View File

@ -41,17 +41,13 @@ async def list_model_records(
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
if base_models and len(base_models) > 0: models = list()
models_raw = list() if base_models:
for base_model in base_models: for base_model in base_models:
models_raw.extend( models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
)
else: else:
models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)] models.extend(record_store.search_by_attr(model_type=model_type))
models = ModelsListValidator.validate_python({"models": models_raw}) return ModelsList(models=models)
return models
@model_records_router.get( @model_records_router.get(
"/i/{key}", "/i/{key}",

View File

@ -107,7 +107,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
config TEXT NOT NULL, config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
); );
""" """
) )
@ -200,6 +202,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e): if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed" msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else: else:
msg = f"A model with key '{key}' is already installed" msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e raise DuplicateModelException(msg) from e

View File

@ -298,7 +298,7 @@ class ModelConfigFactory(object):
@classmethod @classmethod
def make_config( def make_config(
cls, cls,
model_data: Union[dict, ModelConfigBase], model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None, key: Optional[str] = None,
dest_class: Optional[Type] = None, dest_class: Optional[Type] = None,
) -> AnyModelConfig: ) -> AnyModelConfig:

View File

@ -170,6 +170,44 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.all_models() matches = store.all_models()
assert len(matches) == 3 assert len(matches) == 3
def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType("sd-1"),
type=ModelType("main"),
name="nonuniquename",
original_hash="CONFIG1HASH",
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
for c in config1, config2, config3:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
# config4 clashes with config1 and should raise an integrity error
with pytest.raises(DuplicateModelException):
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
def test_filter_2(store: ModelRecordServiceBase): def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig( config1 = MainDiffusersConfig(
@ -196,7 +234,7 @@ def test_filter_2(store: ModelRecordServiceBase):
config4 = MainDiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
name="dup_name1", name="dup_name1",
base=BaseModelType("sd-2"), base=BaseModelType("sdxl"),
type=ModelType("main"), type=ModelType("main"),
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )