mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): faster hashing for spinning disk HDDs
BLAKE3 has poor performance on spinning disks when parallelized. See https://github.com/BLAKE3-team/BLAKE3/issues/31 - Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted. - Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality. - Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixes the aforementioned performance issue - Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3` - Update all calls of the probe to use the app's configured hashing algorithm - Update an external script that probes models - Update tests - Move ModelHash into its own module to avoid circular import issues
This commit is contained in:
@ -1,185 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Fast hashing of diffusers and checkpoint-style models.
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_manager.model_hash import FastModelHash
|
||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||
'a8e693a126ea5b831c96064dc569956f'
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class ModelHash:
    """
    Creates a hash of a model using a specified algorithm.

    Args:
        algorithm: Hashing algorithm to use. Defaults to BLAKE3.
        file_filter: A function that takes a file name and returns True if the file should be included in the hash.

    If the model is a single file, it is hashed directly using the provided algorithm.

    If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.

    Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth

    The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
    that directory hashes are never weaker than the file hashes.

    The SHAKE algorithms produce variable-length output; a fixed length of 32 bytes (shake_128) or
    64 bytes (shake_256) is used so that they yield a stable hexdigest like every other algorithm.

    Raises:
        ValueError: If `algorithm` is neither "blake3" nor available in `hashlib`.

    Usage:
    ```py
    # BLAKE3 hash
    ModelHash().hash("path/to/some/model.safetensors")
    # MD5
    ModelHash("md5").hash("path/to/model/dir/")
    ```
    """

    # hashlib's SHAKE objects require an explicit output length for `hexdigest()`; calling it
    # without one raises TypeError. Twice the nominal security strength is the conventional size.
    _SHAKE_DIGEST_BYTES = {"shake_128": 32, "shake_256": 64}

    def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
        # BLAKE3 is not part of hashlib, so it gets a dedicated hasher.
        if algorithm == "blake3":
            self._hash_file = self._blake3
        elif algorithm in hashlib.algorithms_available:
            self._hash_file = self._get_hashlib(algorithm)
        else:
            raise ValueError(f"Algorithm {algorithm} not available")

        self._file_filter = file_filter or self._default_file_filter

    def hash(self, model_path: Union[str, Path]) -> str:
        """
        Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.

        If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
        directory. The final composite hash is always computed using BLAKE3.

        Args:
            model_path: Path to the model

        Returns:
            str: Hexdigest of the hash of the model

        Raises:
            OSError: If model_path is neither an existing file nor an existing directory.
        """

        model_path = Path(model_path)
        if model_path.is_file():
            return self._hash_file(model_path)
        elif model_path.is_dir():
            return self._hash_dir(model_path)
        else:
            raise OSError(f"Not a valid file or directory: {model_path}")

    def _hash_dir(self, dir: Path) -> str:
        """Compute the hash for all files in a directory and return a hexdigest.

        Args:
            dir: Path to the directory

        Returns:
            str: Hexdigest of the hash of the directory
        """
        model_component_paths = self._get_file_paths(dir, self._file_filter)

        # Sort the paths so the composite hash does not depend on directory traversal order.
        component_hashes: list[str] = [self._hash_file(component) for component in sorted(model_component_paths)]

        # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
        # for the composite hash
        composite_hasher = blake3()
        for h in component_hashes:
            composite_hasher.update(h.encode("utf-8"))
        return composite_hasher.hexdigest()

    @staticmethod
    def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
        """Return a list of all model files in the directory.

        The directory is walked recursively; the filter receives the bare file name, not the full path.

        Args:
            model_path: Path to the model
            file_filter: Function that takes a file name and returns True if the file should be included in the list.

        Returns:
            List of all model files in the directory
        """

        files: list[Path] = []
        for root, _dirs, _files in os.walk(model_path):
            files.extend(Path(root, file) for file in _files if file_filter(file))
        return files

    @staticmethod
    def _blake3(file_path: Path) -> str:
        """Hashes a file using BLAKE3

        Args:
            file_path: Path to the file to hash

        Returns:
            Hexdigest of the hash of the file
        """
        # `blake3.AUTO` lets the library choose the thread count; the file is memory-mapped
        # rather than read through a Python-level buffer.
        file_hasher = blake3(max_threads=blake3.AUTO)
        file_hasher.update_mmap(file_path)
        return file_hasher.hexdigest()

    @staticmethod
    def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
        """Factory function that returns a function to hash a file with the given algorithm.

        Args:
            algorithm: Hashing algorithm to use

        Returns:
            A function that hashes a file using the given algorithm
        """

        # None for fixed-length algorithms; a byte count for the variable-length SHAKE family.
        shake_length = ModelHash._SHAKE_DIGEST_BYTES.get(algorithm)

        def hashlib_hasher(file_path: Path) -> str:
            """Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
            hasher = hashlib.new(algorithm)
            buffer = bytearray(128 * 1024)
            mv = memoryview(buffer)
            # Unbuffered reads straight into the reusable buffer; `readinto` returns 0 at EOF.
            with open(file_path, "rb", buffering=0) as f:
                while n := f.readinto(mv):
                    hasher.update(mv[:n])
            if shake_length is not None:
                return hasher.hexdigest(shake_length)
            return hasher.hexdigest()

        return hashlib_hasher

    @staticmethod
    def _default_file_filter(file_path: str) -> bool:
        """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth

        Args:
            file_path: Path to the file

        Returns:
            True if the file matches the given extensions, otherwise False
        """
        return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -24,7 +25,6 @@ from .config import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -113,9 +113,7 @@ class ModelProbe(object):
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Probe the model at model_path and return its configuration record.
|
||||
@ -160,7 +158,7 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = (
|
||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
||||
|
Reference in New Issue
Block a user