feat(mm): make ModelHash instantiatable, taking an algorithm as arg

This commit is contained in:
psychedelicious 2024-02-28 13:21:29 +11:00 committed by Ryan Dick
parent 3cf3ed55a2
commit 33967cfc6d
3 changed files with 51 additions and 43 deletions

View File

@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
try: try:
hash = ModelHash.hash(self.config.models_path / stanza.path) hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError: except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue

View File

@ -10,13 +10,13 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, Union from typing import Callable, Literal, Union
from blake3 import blake3 from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHMS = Literal[ ALGORITHM = Literal[
"md5", "md5",
"sha1", "sha1",
"sha1_fast", "sha1_fast",
@ -37,10 +37,39 @@ ALGORITHMS = Literal[
class ModelHash: class ModelHash:
"""ModelHash provides one public class method, hash().""" """
Creates a hash of a model using a specified algorithm.
@classmethod :param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage
```py
ModelHash().hash("path/to/some/model.safetensors")
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str:
""" """
Return hexdigest string for model located at model_location. Return hexdigest string for model located at model_location.
@ -48,48 +77,23 @@ class ModelHash:
directory. The final composite hash is always computed using BLAKE3. directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model :param model_location: Path to the model
:param algorithm: Hashing algorithm to use
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file(model_location, algorithm) return self._hash_file(model_location)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location, algorithm) return self._hash_dir(model_location)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod def _hash_dir(self, model_location: Path) -> str:
def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str: """Compute the hash for all files in a directory and return a hexdigest."""
""" model_component_paths = self._get_file_paths(model_location)
Compute the hash for a single file and return its hexdigest.
:param model_location: Path to the model file
:param algorithm: Hashing algorithm to use
"""
if algorithm == "blake3":
return cls._blake3(model_location)
elif algorithm == "sha1_fast":
return cls._sha1_fast(model_location)
elif algorithm in hashlib.algorithms_available:
return cls._hashlib(model_location, algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
"""
Compute the hash for all files in a directory and return a hexdigest.
:param model_location: Path to the model directory
:param algorithm: Hashing algorithm to use
"""
model_component_paths = cls._get_file_paths(model_location)
component_hashes: list[str] = [] component_hashes: list[str] = []
for component in sorted(model_component_paths): for component in sorted(model_component_paths):
component_hashes.append(cls._hash_file(component, algorithm)) component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash # for the composite hash
@ -128,9 +132,13 @@ class ModelHash:
return file_hash.hexdigest() return file_hash.hexdigest()
@staticmethod @staticmethod
def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str: def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using a hashlib algorithm""" """Hashes a file using a hashlib algorithm"""
def hasher(file_path: Path) -> str:
file_hasher = hashlib.new(algorithm) file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
file_hasher.update(f.read()) file_hasher.update(f.read())
return file_hasher.hexdigest() return file_hasher.hexdigest()
return hasher

View File

@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class: if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash.hash(model_path) hash = ModelHash().hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()