mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement psychedelicious recommendations as of 13 November
This commit is contained in:
parent
8afe517204
commit
efbdb75568
@ -41,17 +41,13 @@ async def list_model_records(
|
|||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
if base_models and len(base_models) > 0:
|
models = list()
|
||||||
models_raw = list()
|
if base_models:
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
models_raw.extend(
|
models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||||
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)]
|
models.extend(record_store.search_by_attr(model_type=model_type))
|
||||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
return ModelsList(models=models)
|
||||||
return models
|
|
||||||
|
|
||||||
|
|
||||||
@model_records_router.get(
|
@model_records_router.get(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
|
@ -107,7 +107,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
config TEXT NOT NULL,
|
config TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- unique constraint on combo of name, base and type
|
||||||
|
UNIQUE(name, base, type)
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -200,6 +202,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
if "UNIQUE constraint failed" in str(e):
|
if "UNIQUE constraint failed" in str(e):
|
||||||
if "model_config.path" in str(e):
|
if "model_config.path" in str(e):
|
||||||
msg = f"A model with path '{record.path}' is already installed"
|
msg = f"A model with path '{record.path}' is already installed"
|
||||||
|
elif "model_config.name" in str(e):
|
||||||
|
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||||
else:
|
else:
|
||||||
msg = f"A model with key '{key}' is already installed"
|
msg = f"A model with key '{key}' is already installed"
|
||||||
raise DuplicateModelException(msg) from e
|
raise DuplicateModelException(msg) from e
|
||||||
|
@ -298,7 +298,7 @@ class ModelConfigFactory(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def make_config(
|
def make_config(
|
||||||
cls,
|
cls,
|
||||||
model_data: Union[dict, ModelConfigBase],
|
model_data: Union[dict, AnyModelConfig],
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
dest_class: Optional[Type] = None,
|
dest_class: Optional[Type] = None,
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
|
@ -170,6 +170,44 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
matches = store.all_models()
|
matches = store.all_models()
|
||||||
assert len(matches) == 3
|
assert len(matches) == 3
|
||||||
|
|
||||||
|
def test_unique(store: ModelRecordServiceBase):
|
||||||
|
config1 = MainDiffusersConfig(
|
||||||
|
path="/tmp/config1",
|
||||||
|
base=BaseModelType("sd-1"),
|
||||||
|
type=ModelType("main"),
|
||||||
|
name="nonuniquename",
|
||||||
|
original_hash="CONFIG1HASH",
|
||||||
|
)
|
||||||
|
config2 = MainDiffusersConfig(
|
||||||
|
path="/tmp/config2",
|
||||||
|
base=BaseModelType("sd-2"),
|
||||||
|
type=ModelType("main"),
|
||||||
|
name="nonuniquename",
|
||||||
|
original_hash="CONFIG1HASH",
|
||||||
|
)
|
||||||
|
config3 = VaeDiffusersConfig(
|
||||||
|
path="/tmp/config3",
|
||||||
|
base=BaseModelType("sd-2"),
|
||||||
|
type=ModelType("vae"),
|
||||||
|
name="nonuniquename",
|
||||||
|
original_hash="CONFIG1HASH",
|
||||||
|
)
|
||||||
|
config4 = MainDiffusersConfig(
|
||||||
|
path="/tmp/config4",
|
||||||
|
base=BaseModelType("sd-1"),
|
||||||
|
type=ModelType("main"),
|
||||||
|
name="nonuniquename",
|
||||||
|
original_hash="CONFIG1HASH",
|
||||||
|
)
|
||||||
|
# config1, config2 and config3 are compatible because they have unique combos
|
||||||
|
# of name, type and base
|
||||||
|
for c in config1, config2, config3:
|
||||||
|
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||||
|
|
||||||
|
# config4 clashes with config1 and should raise an integrity error
|
||||||
|
with pytest.raises(DuplicateModelException):
|
||||||
|
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), config4)
|
||||||
|
|
||||||
|
|
||||||
def test_filter_2(store: ModelRecordServiceBase):
|
def test_filter_2(store: ModelRecordServiceBase):
|
||||||
config1 = MainDiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
@ -196,7 +234,7 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sdxl"),
|
||||||
type=ModelType("main"),
|
type=ModelType("main"),
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user