mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): make ModelHash instantiatable, taking an algorithm as arg
This commit is contained in:
parent
3493ae7cdb
commit
5b675d8481
@ -166,7 +166,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
raise DuplicateModelException(
|
raise DuplicateModelException(
|
||||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||||
) from excp
|
) from excp
|
||||||
new_hash = ModelHash.hash(new_path)
|
new_hash = ModelHash().hash(new_path)
|
||||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||||
|
|
||||||
return self._register(
|
return self._register(
|
||||||
@ -468,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
new_path = self._move_model(old_path, new_path)
|
new_path = self._move_model(old_path, new_path)
|
||||||
new_hash = ModelHash.hash(new_path)
|
new_hash = ModelHash().hash(new_path)
|
||||||
model.path = new_path.relative_to(models_dir).as_posix()
|
model.path = new_path.relative_to(models_dir).as_posix()
|
||||||
if model.current_hash != new_hash:
|
if model.current_hash != new_hash:
|
||||||
assert (
|
assert (
|
||||||
|
@ -72,7 +72,7 @@ class MigrateModelYamlToDb1:
|
|||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
try:
|
try:
|
||||||
hash = ModelHash.hash(self.config.models_path / stanza.path)
|
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||||
except OSError:
|
except OSError:
|
||||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||||
continue
|
continue
|
||||||
|
@ -10,13 +10,13 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Union
|
from typing import Callable, Literal, Union
|
||||||
|
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||||
|
|
||||||
ALGORITHMS = Literal[
|
ALGORITHM = Literal[
|
||||||
"md5",
|
"md5",
|
||||||
"sha1",
|
"sha1",
|
||||||
"sha1_fast",
|
"sha1_fast",
|
||||||
@ -37,10 +37,39 @@ ALGORITHMS = Literal[
|
|||||||
|
|
||||||
|
|
||||||
class ModelHash:
|
class ModelHash:
|
||||||
"""ModelHash provides one public class method, hash()."""
|
"""
|
||||||
|
Creates a hash of a model using a specified algorithm.
|
||||||
|
|
||||||
@classmethod
|
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||||
def hash(cls, model_location: Union[str, Path], algorithm: ALGORITHMS = "blake3") -> str:
|
|
||||||
|
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||||
|
|
||||||
|
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||||
|
|
||||||
|
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||||
|
|
||||||
|
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||||
|
that directory hashes are never weaker than the file hashes.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
|
||||||
|
```py
|
||||||
|
ModelHash().hash("path/to/some/model.safetensors")
|
||||||
|
ModelHash("md5").hash("path/to/model/dir/")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
|
||||||
|
if algorithm == "blake3":
|
||||||
|
self._hash_file = self._blake3
|
||||||
|
elif algorithm == "sha1_fast":
|
||||||
|
self._hash_file = self._sha1_fast
|
||||||
|
elif algorithm in hashlib.algorithms_available:
|
||||||
|
self._hash_file = self._get_hashlib(algorithm)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
|
def hash(self, model_location: Union[str, Path]) -> str:
|
||||||
"""
|
"""
|
||||||
Return hexdigest string for model located at model_location.
|
Return hexdigest string for model located at model_location.
|
||||||
|
|
||||||
@ -48,48 +77,23 @@ class ModelHash:
|
|||||||
directory. The final composite hash is always computed using BLAKE3.
|
directory. The final composite hash is always computed using BLAKE3.
|
||||||
|
|
||||||
:param model_location: Path to the model
|
:param model_location: Path to the model
|
||||||
:param algorithm: Hashing algorithm to use
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_location = Path(model_location)
|
model_location = Path(model_location)
|
||||||
if model_location.is_file():
|
if model_location.is_file():
|
||||||
return cls._hash_file(model_location, algorithm)
|
return self._hash_file(model_location)
|
||||||
elif model_location.is_dir():
|
elif model_location.is_dir():
|
||||||
return cls._hash_dir(model_location, algorithm)
|
return self._hash_dir(model_location)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||||
|
|
||||||
@classmethod
|
def _hash_dir(self, model_location: Path) -> str:
|
||||||
def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
|
"""Compute the hash for all files in a directory and return a hexdigest."""
|
||||||
"""
|
model_component_paths = self._get_file_paths(model_location)
|
||||||
Compute the hash for a single file and return its hexdigest.
|
|
||||||
|
|
||||||
:param model_location: Path to the model file
|
|
||||||
:param algorithm: Hashing algorithm to use
|
|
||||||
"""
|
|
||||||
|
|
||||||
if algorithm == "blake3":
|
|
||||||
return cls._blake3(model_location)
|
|
||||||
elif algorithm == "sha1_fast":
|
|
||||||
return cls._sha1_fast(model_location)
|
|
||||||
elif algorithm in hashlib.algorithms_available:
|
|
||||||
return cls._hashlib(model_location, algorithm)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Algorithm {algorithm} not available")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
|
|
||||||
"""
|
|
||||||
Compute the hash for all files in a directory and return a hexdigest.
|
|
||||||
|
|
||||||
:param model_location: Path to the model directory
|
|
||||||
:param algorithm: Hashing algorithm to use
|
|
||||||
"""
|
|
||||||
model_component_paths = cls._get_file_paths(model_location)
|
|
||||||
|
|
||||||
component_hashes: list[str] = []
|
component_hashes: list[str] = []
|
||||||
for component in sorted(model_component_paths):
|
for component in sorted(model_component_paths):
|
||||||
component_hashes.append(cls._hash_file(component, algorithm))
|
component_hashes.append(self._hash_file(component))
|
||||||
|
|
||||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||||
# for the composite hash
|
# for the composite hash
|
||||||
@ -128,9 +132,13 @@ class ModelHash:
|
|||||||
return file_hash.hexdigest()
|
return file_hash.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str:
|
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||||
"""Hashes a file using a hashlib algorithm"""
|
"""Hashes a file using a hashlib algorithm"""
|
||||||
file_hasher = hashlib.new(algorithm)
|
|
||||||
with open(file_path, "rb") as f:
|
def hasher(file_path: Path) -> str:
|
||||||
file_hasher.update(f.read())
|
file_hasher = hashlib.new(algorithm)
|
||||||
return file_hasher.hexdigest()
|
with open(file_path, "rb") as f:
|
||||||
|
file_hasher.update(f.read())
|
||||||
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
|
return hasher
|
||||||
|
@ -147,7 +147,7 @@ class ModelProbe(object):
|
|||||||
if not probe_class:
|
if not probe_class:
|
||||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||||
|
|
||||||
hash = ModelHash.hash(model_path)
|
hash = ModelHash().hash(model_path)
|
||||||
probe = probe_class(model_path)
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
fields["path"] = model_path.as_posix()
|
fields["path"] = model_path.as_posix()
|
||||||
|
Loading…
Reference in New Issue
Block a user