feat(mm): make ModelHash instantiatable, taking an algorithm as arg

This commit is contained in:
psychedelicious 2024-02-28 13:21:29 +11:00 committed by Ryan Dick
parent 3cf3ed55a2
commit 33967cfc6d
3 changed files with 51 additions and 43 deletions

View File

@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = ModelHash.hash(self.config.models_path / stanza.path)
hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue

View File

@ -10,13 +10,13 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib
import os
from pathlib import Path
from typing import Literal, Union
from typing import Callable, Literal, Union
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHMS = Literal[
ALGORITHM = Literal[
"md5",
"sha1",
"sha1_fast",
@ -37,10 +37,39 @@ ALGORITHMS = Literal[
class ModelHash:
"""ModelHash provides one public class method, hash()."""
"""
Creates a hash of a model using a specified algorithm.
@classmethod
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage
```py
ModelHash().hash("path/to/some/model.safetensors")
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
@ -48,48 +77,23 @@ class ModelHash:
directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model
:param algorithm: Hashing algorithm to use
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location, algorithm)
return self._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location, algorithm)
return self._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
"""
Compute the hash for a single file and return its hexdigest.
:param model_location: Path to the model file
:param algorithm: Hashing algorithm to use
"""
if algorithm == "blake3":
return cls._blake3(model_location)
elif algorithm == "sha1_fast":
return cls._sha1_fast(model_location)
elif algorithm in hashlib.algorithms_available:
return cls._hashlib(model_location, algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
@classmethod
def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
"""
Compute the hash for all files in a directory and return a hexdigest.
:param model_location: Path to the model directory
:param algorithm: Hashing algorithm to use
"""
model_component_paths = cls._get_file_paths(model_location)
def _hash_dir(self, model_location: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest."""
model_component_paths = self._get_file_paths(model_location)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(cls._hash_file(component, algorithm))
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
@ -128,9 +132,13 @@ class ModelHash:
return file_hash.hexdigest()
@staticmethod
def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str:
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using a hashlib algorithm"""
file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()
def hasher(file_path: Path) -> str:
file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()
return hasher

View File

@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = ModelHash.hash(model_path)
hash = ModelHash().hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()