chore: ruff lint

This commit is contained in:
psychedelicious 2023-11-14 07:57:07 +11:00
parent 4465f97cdf
commit bc64cde6f9
3 changed files with 13 additions and 13 deletions

View File

@ -42,7 +42,7 @@ async def list_model_records(
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
if base_models and len(base_models) > 0: if base_models and len(base_models) > 0:
models_raw = list() models_raw = []
for base_model in base_models: for base_model in base_models:
models_raw.extend( models_raw.extend(
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)] [x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]

View File

@ -49,7 +49,7 @@ class FastModelHash(object):
def _hash_dir(cls, model_location: Union[str, Path]) -> str: def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {} components: Dict[str, str] = {}
for root, dirs, files in os.walk(model_location): for root, _dirs, files in os.walk(model_location):
for file in files: for file in files:
# only tally tensor files because diffusers config files change slightly # only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted. # depending on how the model was downloaded/converted.
@ -61,6 +61,6 @@ class FastModelHash(object):
# hash all the model hashes together, using alphabetic file order # hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5() md5 = hashlib.md5()
for path, fast_hash in sorted(components.items()): for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8")) md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest() return md5.hexdigest()

View File

@ -52,16 +52,16 @@ def test_type(store: ModelRecordServiceBase):
def test_add(store: ModelRecordServiceBase): def test_add(store: ModelRecordServiceBase):
raw = dict( raw = {
path="/tmp/foo.ckpt", "path": "/tmp/foo.ckpt",
name="model1", "name": "model1",
base=BaseModelType("sd-1"), "base": BaseModelType("sd-1"),
type="main", "type": "main",
config="/tmp/foo.yaml", "config": "/tmp/foo.yaml",
variant="normal", "variant": "normal",
format="checkpoint", "format": "checkpoint",
original_hash="111222333444", "original_hash": "111222333444",
) }
store.add_model("key1", raw) store.add_model("key1", raw)
config1 = store.get_model("key1") config1 = store.get_model("key1")
assert config1 is not None assert config1 is not None