From 7b2f81babbcf567bc93fd937dca21344bfa39b03 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:51:49 +1100 Subject: [PATCH] feat(mm): use blake3 for hashing --- invokeai/backend/model_manager/hash.py | 60 +++++++++++--------------- pyproject.toml | 2 +- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index c4f4165ebf..9ca778a65e 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -7,13 +7,12 @@ from invokeai.backend.model_managre.model_hash import FastModelHash >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') 'a8e693a126ea5b831c96064dc569956f' """ - -import hashlib import os from pathlib import Path -from typing import Dict, Union +from typing import Union -from imohash import hashfile +from blake3 import blake3 +from tqdm import tqdm class FastModelHash(object): @@ -28,53 +27,44 @@ class FastModelHash(object): """ model_location = Path(model_location) if model_location.is_file(): - return cls._hash_file_sha1(model_location) + return cls._hash_file(model_location) elif model_location.is_dir(): return cls._hash_dir(model_location) else: raise OSError(f"Not a valid file or directory: {model_location}") @classmethod - def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str: + def _hash_file(cls, model_location: Union[str, Path]) -> str: """ - Compute full sha1 hash over a single file and return its hexdigest. + Compute full BLAKE3 hash over a single file and return its hexdigest. :param model_location: Path to the model file """ - BLOCK_SIZE = 65536 - file_hash = hashlib.sha1() - with open(model_location, "rb") as f: - data = f.read(BLOCK_SIZE) - file_hash.update(data) - return file_hash.hexdigest() - - @classmethod - def _hash_file_fast(cls, model_location: Union[str, Path]) -> str: - """ - Fasthash a single file and return its hexdigest. - - :param model_location: Path to the model file - """ - # we return md5 hash of the filehash to make it shorter - # cryptographic security not needed here - return hashlib.md5(hashfile(model_location)).hexdigest() + file_hasher = blake3(max_threads=blake3.AUTO) + file_hasher.update_mmap(model_location) + return file_hasher.hexdigest() @classmethod def _hash_dir(cls, model_location: Union[str, Path]) -> str: - components: Dict[str, str] = {} + """ + Compute full BLAKE3 hash over all files in a directory and return its hexdigest. + + :param model_location: Path to the model directory + """ + components: list[str] = [] for root, _dirs, files in os.walk(model_location): for file in files: # only tally tensor files because diffusers config files change slightly # depending on how the model was downloaded/converted. - if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): - continue - path = (Path(root) / file).as_posix() - fast_hash = cls._hash_file_fast(path) - components.update({path: fast_hash}) + if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): + components.append((Path(root, file).as_posix())) - # hash all the model hashes together, using alphabetic file order - md5 = hashlib.md5() - for _path, fast_hash in sorted(components.items()): - md5.update(fast_hash.encode("utf-8")) - return md5.hexdigest() + component_hashes: list[str] = [] + + for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"): + file_hasher = blake3(max_threads=blake3.AUTO) + file_hasher.update_mmap(component) + component_hashes.append(file_hasher.hexdigest()) + + return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest() diff --git a/pyproject.toml b/pyproject.toml index 5345851951..d823638c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ # Auxiliary dependencies, pinned only if necessary. "albumentations", + "blake3", "click", "datasets", "Deprecated", @@ -72,7 +73,6 @@ dependencies = [ "easing-functions", "einops", "facexlib", - "imohash", "matplotlib", # needed for plotting of Penner easing functions "npyscreen", "omegaconf",