feat(mm): use blake3 for hashing

This commit is contained in:
psychedelicious 2024-02-27 20:51:49 +11:00
parent a72056e0df
commit 908e915a71
2 changed files with 26 additions and 36 deletions

View File

@ -7,13 +7,12 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from typing import Union
from imohash import hashfile
from blake3 import blake3
from tqdm import tqdm
class FastModelHash(object):
@ -28,53 +27,44 @@ class FastModelHash(object):
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file_sha1(model_location)
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str:
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Compute full sha1 hash over a single file and return its hexdigest.
Compute full BLAKE3 hash over a single file and return its hexdigest.
:param model_location: Path to the model file
"""
BLOCK_SIZE = 65536
file_hash = hashlib.sha1()
with open(model_location, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
@classmethod
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(model_location)
return file_hasher.hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
"""
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
:param model_location: Path to the model directory
"""
components: list[str] = []
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file_fast(path)
components.update({path: fast_hash})
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
components.append((Path(root, file).as_posix()))
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
component_hashes: list[str] = []
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(component)
component_hashes.append(file_hasher.hexdigest())
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()

View File

@ -64,6 +64,7 @@ dependencies = [
# Auxiliary dependencies, pinned only if necessary.
"albumentations",
"blake3",
"click",
"datasets",
"Deprecated",
@ -72,7 +73,6 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"imohash",
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",