feat(mm): use blake3 for hashing

This commit is contained in:
psychedelicious 2024-02-27 20:51:49 +11:00
parent 2b1cb569eb
commit 7b2f81babb
2 changed files with 26 additions and 36 deletions

View File

@ -7,13 +7,12 @@ from invokeai.backend.model_manager.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f' 'a8e693a126ea5b831c96064dc569956f'
""" """
import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Union
from imohash import hashfile from blake3 import blake3
from tqdm import tqdm
class FastModelHash(object): class FastModelHash(object):
@ -28,53 +27,44 @@ class FastModelHash(object):
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file_sha1(model_location) return cls._hash_file(model_location)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location) return cls._hash_dir(model_location)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod @classmethod
def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str: def _hash_file(cls, model_location: Union[str, Path]) -> str:
""" """
Compute full sha1 hash over a single file and return its hexdigest. Compute full BLAKE3 hash over a single file and return its hexdigest.
:param model_location: Path to the model file :param model_location: Path to the model file
""" """
BLOCK_SIZE = 65536 file_hasher = blake3(max_threads=blake3.AUTO)
file_hash = hashlib.sha1() file_hasher.update_mmap(model_location)
with open(model_location, "rb") as f: return file_hasher.hexdigest()
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
@classmethod
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod @classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str: def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {} """
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
:param model_location: Path to the model directory
"""
components: list[str] = []
for root, _dirs, files in os.walk(model_location): for root, _dirs, files in os.walk(model_location):
for file in files: for file in files:
# only tally tensor files because diffusers config files change slightly # only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted. # depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue components.append((Path(root, file).as_posix()))
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file_fast(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order component_hashes: list[str] = []
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()): for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
md5.update(fast_hash.encode("utf-8")) file_hasher = blake3(max_threads=blake3.AUTO)
return md5.hexdigest() file_hasher.update_mmap(component)
component_hashes.append(file_hasher.hexdigest())
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()

View File

@ -64,6 +64,7 @@ dependencies = [
# Auxiliary dependencies, pinned only if necessary. # Auxiliary dependencies, pinned only if necessary.
"albumentations", "albumentations",
"blake3",
"click", "click",
"datasets", "datasets",
"Deprecated", "Deprecated",
@ -72,7 +73,6 @@ dependencies = [
"easing-functions", "easing-functions",
"einops", "einops",
"facexlib", "facexlib",
"imohash",
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"npyscreen", "npyscreen",
"omegaconf", "omegaconf",