feat(mm): add hashing algos to ModelHash

- Some algos are slow, so it is now just called ModelHash
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
This commit is contained in:
psychedelicious 2024-02-28 01:50:05 +11:00 committed by Ryan Dick
parent 4b073157b8
commit 6b41246b2d
3 changed files with 84 additions and 44 deletions

View File

@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory, ModelConfigFactory,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.hash import FastModelHash from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig) ModelsValidator = TypeAdapter(AnyModelConfig)
@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
try: try:
hash = FastModelHash.hash(self.config.models_path / stanza.path) hash = ModelHash.hash(self.config.models_path / stanza.path)
except OSError: except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue

View File

@ -7,53 +7,82 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f' 'a8e693a126ea5b831c96064dc569956f'
""" """
import cProfile import hashlib
import os import os
import pstats
import threading
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from typing import Literal, Union
from typing import Union
from blake3 import blake3 from blake3 import blake3
from tqdm import tqdm
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHMS = Literal[
"md5",
"sha1",
"sha1_fast",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class FastModelHash(object): class ModelHash:
"""FastModelHash obect provides one public class method, hash().""" """ModelHash provides one public class method, hash()."""
@classmethod @classmethod
def hash(cls, model_location: Union[str, Path]) -> str: def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
""" """
Return hexdigest string for model located at model_location. Return hexdigest string for model located at model_location.
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model :param model_location: Path to the model
:param algorithm: Hashing algorithm to use
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file(model_location) return cls._hash_file(model_location, algorithm)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location) return cls._hash_dir(model_location, algorithm)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod @classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str: def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
""" """
Compute full BLAKE3 hash over a single file and return its hexdigest. Compute the hash for a single file and return its hexdigest.
:param model_location: Path to the model file :param model_location: Path to the model file
:param algorithm: Hashing algorithm to use
""" """
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(model_location) if algorithm == "blake3":
return file_hasher.hexdigest() return cls._blake3(model_location)
elif algorithm == "sha1_fast":
return cls._sha1_fast(model_location)
elif algorithm in hashlib.algorithms_available:
return cls._hashlib(model_location, algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod @classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str: def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
""" """
Compute full BLAKE3 hash over all files in a directory and return its hexdigest. Compute the hash for all files in a directory and return a hexdigest.
:param model_location: Path to the model directory :param model_location: Path to the model directory
:param algorithm: Hashing algorithm to use
""" """
components: list[str] = [] components: list[str] = []
@ -61,31 +90,42 @@ class FastModelHash(object):
for file in files: for file in files:
# only tally tensor files because diffusers config files change slightly # only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted. # depending on how the model was downloaded/converted.
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if file.endswith(MODEL_FILE_EXTENSIONS):
components.append((Path(root, file).resolve().as_posix())) components.append((Path(root, file).as_posix()))
component_hashes: list[str] = [] component_hashes: list[str] = []
for component in sorted(components):
component_hashes.append(cls._hash_file(component, algorithm))
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"): # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
file_hasher = blake3(max_threads=blake3.AUTO) # for the composite hash
file_hasher.update_mmap(component) composite_hasher = blake3()
component_hashes.append(file_hasher.hexdigest()) for h in components:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest() @staticmethod
def _blake3(file_path: Union[str, Path]) -> str:
"""Hashes a file using BLAKE3"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _sha1_fast(file_path: Union[str, Path]) -> str:
"""Hashes a file using SHA1, but with a block size of 2**16. The result is not a standard SHA1 hash due to the
# padding introduced by the block size. The algorithm is, however, very fast."""
BLOCK_SIZE = 2**16
file_hash = hashlib.sha1()
with open(file_path, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
if __name__ == "__main__": @staticmethod
with TemporaryDirectory() as tempdir: def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str:
profile_path = Path(tempdir, "profile_results.pstats").as_posix() """Hashes a file using a hashlib algorithm"""
profiler = cProfile.Profile() file_hasher = hashlib.new(algorithm)
profiler.enable() with open(file_path, "rb") as f:
t = threading.Thread( file_hasher.update(f.read())
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",) return file_hasher.hexdigest()
)
t.start()
t.join()
profiler.disable()
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
stats.dump_stats(profile_path)
os.system(f"snakeviz {profile_path}")

View File

@ -21,7 +21,7 @@ from .config import (
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
) )
from .hash import FastModelHash from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any] CkptType = Dict[str, Any]
@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class: if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path) hash = ModelHash.hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()