feat(mm): make ModelHash instantiatable, taking an algorithm as arg

This commit is contained in:
psychedelicious 2024-02-28 13:21:29 +11:00
parent 3493ae7cdb
commit 5b675d8481
4 changed files with 53 additions and 45 deletions

View File

@ -166,7 +166,7 @@ class ModelInstallService(ModelInstallServiceBase):
raise DuplicateModelException( raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}" f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp ) from excp
new_hash = ModelHash.hash(new_path) new_hash = ModelHash().hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted." assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register( return self._register(
@ -468,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
new_path = models_dir / model.base.value / model.type.value / model.name new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.") self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path) new_path = self._move_model(old_path, new_path)
new_hash = ModelHash.hash(new_path) new_hash = ModelHash().hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix() model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash: if model.current_hash != new_hash:
assert ( assert (

View File

@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
try: try:
hash = ModelHash.hash(self.config.models_path / stanza.path) hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError: except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue

View File

@ -10,13 +10,13 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, Union from typing import Callable, Literal, Union
from blake3 import blake3 from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHMS = Literal[ ALGORITHM = Literal[
"md5", "md5",
"sha1", "sha1",
"sha1_fast", "sha1_fast",
@ -37,10 +37,39 @@ ALGORITHMS = Literal[
class ModelHash: class ModelHash:
"""ModelHash provides one public class method, hash().""" """
Creates a hash of a model using a specified algorithm.
@classmethod :param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage
```py
ModelHash().hash("path/to/some/model.safetensors")
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str:
""" """
Return hexdigest string for model located at model_location. Return hexdigest string for model located at model_location.
@ -48,48 +77,23 @@ class ModelHash:
directory. The final composite hash is always computed using BLAKE3. directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model :param model_location: Path to the model
:param algorithm: Hashing algorithm to use
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file(model_location, algorithm) return self._hash_file(model_location)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location, algorithm) return self._hash_dir(model_location)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod def _hash_dir(self, model_location: Path) -> str:
def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str: """Compute the hash for all files in a directory and return a hexdigest."""
""" model_component_paths = self._get_file_paths(model_location)
Compute the hash for a single file and return its hexdigest.
:param model_location: Path to the model file
:param algorithm: Hashing algorithm to use
"""
if algorithm == "blake3":
return cls._blake3(model_location)
elif algorithm == "sha1_fast":
return cls._sha1_fast(model_location)
elif algorithm in hashlib.algorithms_available:
return cls._hashlib(model_location, algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
"""
Compute the hash for all files in a directory and return a hexdigest.
:param model_location: Path to the model directory
:param algorithm: Hashing algorithm to use
"""
model_component_paths = cls._get_file_paths(model_location)
component_hashes: list[str] = [] component_hashes: list[str] = []
for component in sorted(model_component_paths): for component in sorted(model_component_paths):
component_hashes.append(cls._hash_file(component, algorithm)) component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm # BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash # for the composite hash
@ -128,9 +132,13 @@ class ModelHash:
return file_hash.hexdigest() return file_hash.hexdigest()
@staticmethod @staticmethod
def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str: def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using a hashlib algorithm""" """Hashes a file using a hashlib algorithm"""
def hasher(file_path: Path) -> str:
file_hasher = hashlib.new(algorithm) file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
file_hasher.update(f.read()) file_hasher.update(f.read())
return file_hasher.hexdigest() return file_hasher.hexdigest()
return hasher

View File

@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class: if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash.hash(model_path) hash = ModelHash().hash(model_path)
probe = probe_class(model_path) probe = probe_class(model_path)
fields["path"] = model_path.as_posix() fields["path"] = model_path.as_posix()