feat(mm): faster hashing for spinning disk HDDs

BLAKE3 has poor performance on spinning disks when parallelized. See https://github.com/BLAKE3-team/BLAKE3/issues/31

- Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted.
- Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality.
- Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixes the aforementioned performance issue
- Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3`
- Update all calls of the probe to use the app's configured hashing algorithm
- Update an external script that probes models
- Update tests
- Move ModelHash into its own module to avoid circuclar import issues
This commit is contained in:
psychedelicious
2024-03-14 09:44:55 +11:00
parent 8287fcf097
commit eb6e6548ed
6 changed files with 78 additions and 33 deletions

View File

@ -1,12 +1,4 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
@ -15,9 +7,9 @@ from typing import Callable, Literal, Optional, Union
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
from invokeai.app.util.misc import uuid_string
ALGORITHM = Literal[
HASHING_ALGORITHMS = Literal[
"md5",
"sha1",
"sha224",
@ -33,7 +25,10 @@ ALGORITHM = Literal[
"shake_128",
"shake_256",
"blake3",
"blake3_single",
"random",
]
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
class ModelHash:
@ -53,6 +48,8 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
Usage:
```py
# BLAKE3 hash
@ -62,11 +59,17 @@ class ModelHash:
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
def __init__(
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "blake3_single":
self._hash_file = self._blake3_single
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
elif algorithm == "random":
self._hash_file = self._random
else:
raise ValueError(f"Algorithm {algorithm} not available")
@ -137,7 +140,7 @@ class ModelHash:
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
Args:
file_path: Path to the file to hash
@ -150,7 +153,21 @@ class ModelHash:
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
def _blake3_single(file_path: Path) -> str:
"""Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3()
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
@ -172,6 +189,13 @@ class ModelHash:
return hashlib_hasher
@staticmethod
def _random(_file_path: Path) -> str:
"""Returns a random string. This is not a hash.
The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
return blake3(uuid_string().encode()).hexdigest()
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth

View File

@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -24,7 +25,6 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -113,9 +113,7 @@ class ModelProbe(object):
@classmethod
def probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
@ -160,7 +158,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = (
fields.get("default_settings") or probe.get_default_settings(fields["name"])