From 33967cfc6dbe35b3c7d6cae9a71eac4d980c2de4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:21:29 +1100 Subject: [PATCH] feat(mm): make ModelHash instantiatable, taking an algorithm as arg --- .../migrations/util/migrate_yaml_config_1.py | 2 +- invokeai/backend/model_manager/hash.py | 90 ++++++++++--------- invokeai/backend/model_manager/probe.py | 2 +- 3 files changed, 51 insertions(+), 43 deletions(-) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py index a52bb4f599..be4d5f0140 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py @@ -72,7 +72,7 @@ class MigrateModelYamlToDb1: base_type, model_type, model_name = str(model_key).split("/") try: - hash = ModelHash.hash(self.config.models_path / stanza.path) + hash = ModelHash().hash(self.config.models_path / stanza.path) except OSError: self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") continue diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index a7ac014194..1139a1dacf 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -10,13 +10,13 @@ from invokeai.backend.model_managre.model_hash import FastModelHash import hashlib import os from pathlib import Path -from typing import Literal, Union +from typing import Callable, Literal, Union from blake3 import blake3 MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") -ALGORITHMS = Literal[ +ALGORITHM = Literal[ "md5", "sha1", "sha1_fast", @@ -37,10 +37,39 @@ ALGORITHMS = Literal[ class ModelHash: - """ModelHash provides one public class method, hash().""" + """ + Creates a hash of a model using a specified algorithm. - @classmethod - def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str: + :param algorithm: Hashing algorithm to use. Defaults to BLAKE3. + + If the model is a single file, it is hashed directly using the provided algorithm. + + If the model is a directory, each model weights file in the directory is hashed using the provided algorithm. + + Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth + + The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring + that directory hashes are never weaker than the file hashes. + + Usage + + ```py + ModelHash().hash("path/to/some/model.safetensors") + ModelHash("md5").hash("path/to/model/dir/") + ``` + """ + + def __init__(self, algorithm: ALGORITHM = "blake3") -> None: + if algorithm == "blake3": + self._hash_file = self._blake3 + elif algorithm == "sha1_fast": + self._hash_file = self._sha1_fast + elif algorithm in hashlib.algorithms_available: + self._hash_file = self._get_hashlib(algorithm) + else: + raise ValueError(f"Algorithm {algorithm} not available") + + def hash(self, model_location: Union[str, Path]) -> str: """ Return hexdigest string for model located at model_location. @@ -48,48 +77,23 @@ class ModelHash: directory. The final composite hash is always computed using BLAKE3. :param model_location: Path to the model - :param algorithm: Hashing algorithm to use """ model_location = Path(model_location) if model_location.is_file(): - return cls._hash_file(model_location, algorithm) + return self._hash_file(model_location) elif model_location.is_dir(): - return cls._hash_dir(model_location, algorithm) + return self._hash_dir(model_location) else: raise OSError(f"Not a valid file or directory: {model_location}") - @classmethod - def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str: - """ - Compute the hash for a single file and return its hexdigest. - - :param model_location: Path to the model file - :param algorithm: Hashing algorithm to use - """ - - if algorithm == "blake3": - return cls._blake3(model_location) - elif algorithm == "sha1_fast": - return cls._sha1_fast(model_location) - elif algorithm in hashlib.algorithms_available: - return cls._hashlib(model_location, algorithm) - else: - raise ValueError(f"Algorithm {algorithm} not available") - - @classmethod - def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str: - """ - Compute the hash for all files in a directory and return a hexdigest. - - :param model_location: Path to the model directory - :param algorithm: Hashing algorithm to use - """ - model_component_paths = cls._get_file_paths(model_location) + def _hash_dir(self, model_location: Path) -> str: + """Compute the hash for all files in a directory and return a hexdigest.""" + model_component_paths = self._get_file_paths(model_location) component_hashes: list[str] = [] for component in sorted(model_component_paths): - component_hashes.append(cls._hash_file(component, algorithm)) + component_hashes.append(self._hash_file(component)) # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm # for the composite hash @@ -128,9 +132,13 @@ class ModelHash: return file_hash.hexdigest() @staticmethod - def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str: + def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: """Hashes a file using a hashlib algorithm""" - file_hasher = hashlib.new(algorithm) - with open(file_path, "rb") as f: - file_hasher.update(f.read()) - return file_hasher.hexdigest() + + def hasher(file_path: Path) -> str: + file_hasher = hashlib.new(algorithm) + with open(file_path, "rb") as f: + file_hasher.update(f.read()) + return file_hasher.hexdigest() + + return hasher diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 1611a76558..a7250f33d1 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -147,7 +147,7 @@ class ModelProbe(object): if not probe_class: raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") - hash = ModelHash.hash(model_path) + hash = ModelHash().hash(model_path) probe = probe_class(model_path) fields["path"] = model_path.as_posix()