mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff lint
This commit is contained in:
parent
4465f97cdf
commit
bc64cde6f9
@ -42,7 +42,7 @@ async def list_model_records(
|
|||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
if base_models and len(base_models) > 0:
|
if base_models and len(base_models) > 0:
|
||||||
models_raw = list()
|
models_raw = []
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
models_raw.extend(
|
models_raw.extend(
|
||||||
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
|
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
|
||||||
|
@ -49,7 +49,7 @@ class FastModelHash(object):
|
|||||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||||
components: Dict[str, str] = {}
|
components: Dict[str, str] = {}
|
||||||
|
|
||||||
for root, dirs, files in os.walk(model_location):
|
for root, _dirs, files in os.walk(model_location):
|
||||||
for file in files:
|
for file in files:
|
||||||
# only tally tensor files because diffusers config files change slightly
|
# only tally tensor files because diffusers config files change slightly
|
||||||
# depending on how the model was downloaded/converted.
|
# depending on how the model was downloaded/converted.
|
||||||
@ -61,6 +61,6 @@ class FastModelHash(object):
|
|||||||
|
|
||||||
# hash all the model hashes together, using alphabetic file order
|
# hash all the model hashes together, using alphabetic file order
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
for path, fast_hash in sorted(components.items()):
|
for _path, fast_hash in sorted(components.items()):
|
||||||
md5.update(fast_hash.encode("utf-8"))
|
md5.update(fast_hash.encode("utf-8"))
|
||||||
return md5.hexdigest()
|
return md5.hexdigest()
|
||||||
|
@ -52,16 +52,16 @@ def test_type(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
|
|
||||||
def test_add(store: ModelRecordServiceBase):
|
def test_add(store: ModelRecordServiceBase):
|
||||||
raw = dict(
|
raw = {
|
||||||
path="/tmp/foo.ckpt",
|
"path": "/tmp/foo.ckpt",
|
||||||
name="model1",
|
"name": "model1",
|
||||||
base=BaseModelType("sd-1"),
|
"base": BaseModelType("sd-1"),
|
||||||
type="main",
|
"type": "main",
|
||||||
config="/tmp/foo.yaml",
|
"config": "/tmp/foo.yaml",
|
||||||
variant="normal",
|
"variant": "normal",
|
||||||
format="checkpoint",
|
"format": "checkpoint",
|
||||||
original_hash="111222333444",
|
"original_hash": "111222333444",
|
||||||
)
|
}
|
||||||
store.add_model("key1", raw)
|
store.add_model("key1", raw)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert config1 is not None
|
assert config1 is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user