InvokeAI/invokeai/backend/model_manager/hash.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

92 lines
3.1 KiB
Python
Raw Normal View History

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import cProfile
import os
import pstats
import threading
from pathlib import Path
from tempfile import TemporaryDirectory
2024-02-27 09:51:49 +00:00
from typing import Union
2024-02-27 09:51:49 +00:00
from blake3 import blake3
from tqdm import tqdm
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
2024-02-27 09:51:49 +00:00
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
2024-02-27 09:51:49 +00:00
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
2024-02-27 09:51:49 +00:00
Compute full BLAKE3 hash over a single file and return its hexdigest.
:param model_location: Path to the model file
"""
2024-02-27 09:51:49 +00:00
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(model_location)
return file_hasher.hexdigest()
@classmethod
2024-02-27 09:51:49 +00:00
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
"""
2024-02-27 09:51:49 +00:00
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
2024-02-27 09:51:49 +00:00
:param model_location: Path to the model directory
"""
2024-02-27 09:51:49 +00:00
components: list[str] = []
2023-11-13 20:57:07 +00:00
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
2024-02-27 09:51:49 +00:00
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
components.append((Path(root, file).resolve().as_posix()))
2024-02-27 09:51:49 +00:00
component_hashes: list[str] = []
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(component)
component_hashes.append(file_hasher.hexdigest())
2024-02-27 09:51:49 +00:00
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
if __name__ == "__main__":
with TemporaryDirectory() as tempdir:
profile_path = Path(tempdir, "profile_results.pstats").as_posix()
profiler = cProfile.Profile()
profiler.enable()
t = threading.Thread(
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",)
)
t.start()
t.join()
profiler.disable()
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
stats.dump_stats(profile_path)
os.system(f"snakeviz {profile_path}")