mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""
|
|
Fast hashing of diffusers and checkpoint-style models.
|
|
|
|
Usage:
|
|
from invokeai.backend.model_managre.model_hash import FastModelHash
|
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
|
'a8e693a126ea5b831c96064dc569956f'
|
|
"""
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
from blake3 import blake3
|
|
from tqdm import tqdm
|
|
|
|
|
|
class FastModelHash(object):
|
|
"""FastModelHash obect provides one public class method, hash()."""
|
|
|
|
@classmethod
|
|
def hash(cls, model_location: Union[str, Path]) -> str:
|
|
"""
|
|
Return hexdigest string for model located at model_location.
|
|
|
|
:param model_location: Path to the model
|
|
"""
|
|
model_location = Path(model_location)
|
|
if model_location.is_file():
|
|
return cls._hash_file(model_location)
|
|
elif model_location.is_dir():
|
|
return cls._hash_dir(model_location)
|
|
else:
|
|
raise OSError(f"Not a valid file or directory: {model_location}")
|
|
|
|
@classmethod
|
|
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
|
"""
|
|
Compute full BLAKE3 hash over a single file and return its hexdigest.
|
|
|
|
:param model_location: Path to the model file
|
|
"""
|
|
file_hasher = blake3(max_threads=blake3.AUTO)
|
|
file_hasher.update_mmap(model_location)
|
|
return file_hasher.hexdigest()
|
|
|
|
@classmethod
|
|
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
|
"""
|
|
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
|
|
|
|
:param model_location: Path to the model directory
|
|
"""
|
|
components: list[str] = []
|
|
|
|
for root, _dirs, files in os.walk(model_location):
|
|
for file in files:
|
|
# only tally tensor files because diffusers config files change slightly
|
|
# depending on how the model was downloaded/converted.
|
|
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
|
components.append((Path(root, file).as_posix()))
|
|
|
|
component_hashes: list[str] = []
|
|
|
|
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
|
|
file_hasher = blake3(max_threads=blake3.AUTO)
|
|
file_hasher.update_mmap(component)
|
|
component_hashes.append(file_hasher.hexdigest())
|
|
|
|
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
|