mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add hashing algos to ModelHash
- Some algos are slow, so it is now just called ModelHash - Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
This commit is contained in:
parent
4b073157b8
commit
6b41246b2d
@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash.hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
@ -7,53 +7,82 @@ from invokeai.backend.model_manager.model_hash import FastModelHash
|
||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||
'a8e693a126ea5b831c96064dc569956f'
|
||||
"""
|
||||
import cProfile
|
||||
import hashlib
|
||||
import os
|
||||
import pstats
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from blake3 import blake3
|
||||
from tqdm import tqdm
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHMS = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha1_fast",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
class ModelHash:
|
||||
"""ModelHash provides one public class method, hash()."""
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
|
||||
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
:param model_location: Path to the model
|
||||
:param algorithm: Hashing algorithm to use
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
return cls._hash_file(model_location, algorithm)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
return cls._hash_dir(model_location, algorithm)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
|
||||
"""
|
||||
Compute full BLAKE3 hash over a single file and return its hexdigest.
|
||||
Compute the hash for a single file and return its hexdigest.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
:param algorithm: Hashing algorithm to use
|
||||
"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(model_location)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
if algorithm == "blake3":
|
||||
return cls._blake3(model_location)
|
||||
elif algorithm == "sha1_fast":
|
||||
return cls._sha1_fast(model_location)
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
return cls._hashlib(model_location, algorithm)
|
||||
else:
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
|
||||
"""
|
||||
Compute full BLAKE3 hash over all files in a directory and return its hexdigest.
|
||||
Compute the hash for all files in a directory and return a hexdigest.
|
||||
|
||||
:param model_location: Path to the model directory
|
||||
:param algorithm: Hashing algorithm to use
|
||||
"""
|
||||
components: list[str] = []
|
||||
|
||||
@ -61,31 +90,42 @@ class FastModelHash(object):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
components.append((Path(root, file).resolve().as_posix()))
|
||||
if file.endswith(MODEL_FILE_EXTENSIONS):
|
||||
components.append((Path(root, file).as_posix()))
|
||||
|
||||
component_hashes: list[str] = []
|
||||
for component in sorted(components):
|
||||
component_hashes.append(cls._hash_file(component, algorithm))
|
||||
|
||||
for component in tqdm(sorted(components), desc=f"Hashing model components for {model_location}"):
|
||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||
# for the composite hash
|
||||
composite_hasher = blake3()
|
||||
for h in components:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _blake3(file_path: Union[str, Path]) -> str:
|
||||
"""Hashes a file using BLAKE3"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(component)
|
||||
component_hashes.append(file_hasher.hexdigest())
|
||||
file_hasher.update_mmap(file_path)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
|
||||
@staticmethod
|
||||
def _sha1_fast(file_path: Union[str, Path]) -> str:
    """Return the SHA1 hexdigest of the first 2**16 bytes of a file.

    This is the "fast" hash: it is NOT the SHA1 of the whole file. Only the
    leading block of at most ``BLOCK_SIZE`` bytes is read and hashed, so two
    files that share their first 64 KiB produce the same digest. For files of
    ``BLOCK_SIZE`` bytes or fewer the result equals the standard SHA1.

    :param file_path: Path to the file to hash
    :return: Hexadecimal SHA1 digest of the first ``BLOCK_SIZE`` bytes
    :raises OSError: If the file cannot be opened or read
    """
    BLOCK_SIZE = 2**16
    file_hash = hashlib.sha1()
    with open(file_path, "rb") as f:
        # A single read, not a loop: hashing only the first block is what makes
        # this algorithm fast. Reading further would change the digest of every
        # file larger than BLOCK_SIZE.
        file_hash.update(f.read(BLOCK_SIZE))
    return file_hash.hexdigest()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with TemporaryDirectory() as tempdir:
|
||||
profile_path = Path(tempdir, "profile_results.pstats").as_posix()
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
t = threading.Thread(
|
||||
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",)
|
||||
)
|
||||
t.start()
|
||||
t.join()
|
||||
profiler.disable()
|
||||
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
|
||||
stats.dump_stats(profile_path)
|
||||
|
||||
os.system(f"snakeviz {profile_path}")
|
||||
@staticmethod
|
||||
def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str:
    """Hash a file with a ``hashlib`` algorithm and return its hexdigest.

    The file is streamed in fixed-size chunks rather than read in one call, so
    memory use stays bounded regardless of file size. The digest is identical
    to hashing the whole file at once.

    :param file_path: Path to the file to hash
    :param algorithm: Name of a hashlib algorithm, passed to ``hashlib.new()``
    :return: Hexadecimal digest of the file contents
    :raises ValueError: If hashlib does not support ``algorithm``
    :raises OSError: If the file cannot be opened or read
    """
    # SHAKE algorithms are extendable-output functions: hexdigest() requires an
    # explicit length and raises TypeError without one. The lengths below give
    # 256-bit and 512-bit digests respectively.
    shake_digest_bytes = {"shake_128": 32, "shake_256": 64}
    chunk_size = 2**20  # 1 MiB per read

    file_hasher = hashlib.new(algorithm)
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            file_hasher.update(chunk)

    if algorithm in shake_digest_bytes:
        return file_hasher.hexdigest(shake_digest_bytes[algorithm])
    return file_hasher.hexdigest()
|
||||
|
@ -21,7 +21,7 @@ from .config import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -147,7 +147,7 @@ class ModelProbe(object):
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
hash = ModelHash.hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
|
Loading…
Reference in New Issue
Block a user