feat(mm): add hashing algos to ModelHash

- Some algos are slow, so it is now just called ModelHash
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
This commit is contained in:
psychedelicious 2024-02-28 01:50:05 +11:00 committed by Ryan Dick
parent 4b073157b8
commit 6b41246b2d
3 changed files with 84 additions and 44 deletions

View File

@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
hash = ModelHash.hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue

View File

@ -7,53 +7,82 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import cProfile
import hashlib
import os
import pstats
import threading
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union
from typing import Literal, Union
from blake3 import blake3
from tqdm import tqdm
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHMS = Literal[
"md5",
"sha1",
"sha1_fast",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
class ModelHash:
"""ModelHash provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
"""
Return hexdigest string for model located at model_location.
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model
:param algorithm: Hashing algorithm to use
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
return cls._hash_file(model_location, algorithm)
elif model_location.is_dir():
return cls._hash_dir(model_location)
return cls._hash_dir(model_location, algorithm)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
"""
Compute full BLAKE3 hash over a single file and return its hexdigest.
Compute the hash for a single file and return its hexdigest.
:param model_location: Path to the model file
:param algorithm: Hashing algorithm to use
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(model_location)
return file_hasher.hexdigest()
if algorithm == "blake3":
return cls._blake3(model_location)
elif algorithm == "sha1_fast":
return cls._sha1_fast(model_location)
elif algorithm in hashlib.algorithms_available:
return cls._hashlib(model_location, algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
"""
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
Compute the hash for all files in a directory and return a hexdigest.
:param model_location: Path to the model directory
:param algorithm: Hashing algorithm to use
"""
components: list[str] = []
@ -61,31 +90,42 @@ class FastModelHash(object):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
components.append((Path(root, file).resolve().as_posix()))
if file.endswith(MODEL_FILE_EXTENSIONS):
components.append((Path(root, file).as_posix()))
component_hashes: list[str] = []
for component in sorted(components):
component_hashes.append(cls._hash_file(component, algorithm))
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(component)
component_hashes.append(file_hasher.hexdigest())
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in components:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
@staticmethod
def _blake3(file_path: Union[str, Path]) -> str:
"""Hashes a file using BLAKE3"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _sha1_fast(file_path: Union[str, Path]) -> str:
"""Hashes a file using SHA1, but with a block size of 2**16. The result is not a standard SHA1 hash due to the
# padding introduced by the block size. The algorithm is, however, very fast."""
BLOCK_SIZE = 2**16
file_hash = hashlib.sha1()
with open(file_path, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
if __name__ == "__main__":
with TemporaryDirectory() as tempdir:
profile_path = Path(tempdir, "profile_results.pstats").as_posix()
profiler = cProfile.Profile()
profiler.enable()
t = threading.Thread(
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",)
)
t.start()
t.join()
profiler.disable()
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
stats.dump_stats(profile_path)
os.system(f"snakeviz {profile_path}")
@staticmethod
def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str:
"""Hashes a file using a hashlib algorithm"""
file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()

View File

@ -21,7 +21,7 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
hash = ModelHash.hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()