mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): modularize ModelHash to facilitate testing
This commit is contained in:
parent
982076d7d7
commit
ec8ed530a7
@ -50,6 +50,7 @@ class ModelHash:
|
|||||||
:param model_location: Path to the model
|
:param model_location: Path to the model
|
||||||
:param algorithm: Hashing algorithm to use
|
:param algorithm: Hashing algorithm to use
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_location = Path(model_location)
|
model_location = Path(model_location)
|
||||||
if model_location.is_file():
|
if model_location.is_file():
|
||||||
return cls._hash_file(model_location, algorithm)
|
return cls._hash_file(model_location, algorithm)
|
||||||
@ -59,7 +60,7 @@ class ModelHash:
|
|||||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _hash_file(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
|
def _hash_file(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
|
||||||
"""
|
"""
|
||||||
Compute the hash for a single file and return its hexdigest.
|
Compute the hash for a single file and return its hexdigest.
|
||||||
|
|
||||||
@ -77,44 +78,48 @@ class ModelHash:
|
|||||||
raise ValueError(f"Algorithm {algorithm} not available")
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _hash_dir(cls, model_location: Union[str, Path], algorithm: ALGORITHMS) -> str:
|
def _hash_dir(cls, model_location: Path, algorithm: ALGORITHMS) -> str:
|
||||||
"""
|
"""
|
||||||
Compute the hash for all files in a directory and return a hexdigest.
|
Compute the hash for all files in a directory and return a hexdigest.
|
||||||
|
|
||||||
:param model_location: Path to the model directory
|
:param model_location: Path to the model directory
|
||||||
:param algorithm: Hashing algorithm to use
|
:param algorithm: Hashing algorithm to use
|
||||||
"""
|
"""
|
||||||
components: list[str] = []
|
model_component_paths = cls._get_file_paths(model_location)
|
||||||
|
|
||||||
for root, _dirs, files in os.walk(model_location):
|
|
||||||
for file in files:
|
|
||||||
# only tally tensor files because diffusers config files change slightly
|
|
||||||
# depending on how the model was downloaded/converted.
|
|
||||||
if file.endswith(MODEL_FILE_EXTENSIONS):
|
|
||||||
components.append((Path(root, file).as_posix()))
|
|
||||||
|
|
||||||
component_hashes: list[str] = []
|
component_hashes: list[str] = []
|
||||||
for component in sorted(components):
|
for component in sorted(model_component_paths):
|
||||||
component_hashes.append(cls._hash_file(component, algorithm))
|
component_hashes.append(cls._hash_file(component, algorithm))
|
||||||
|
|
||||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||||
# for the composite hash
|
# for the composite hash
|
||||||
composite_hasher = blake3()
|
composite_hasher = blake3()
|
||||||
for h in components:
|
for h in component_hashes:
|
||||||
composite_hasher.update(h.encode("utf-8"))
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
return composite_hasher.hexdigest()
|
return composite_hasher.hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_file_paths(cls, dir: Path) -> list[Path]:
|
||||||
|
"""Return a list of all model files in the directory."""
|
||||||
|
files: list[Path] = []
|
||||||
|
for root, _dirs, _files in os.walk(dir):
|
||||||
|
for file in _files:
|
||||||
|
if file.endswith(MODEL_FILE_EXTENSIONS):
|
||||||
|
files.append(Path(root, file))
|
||||||
|
return files
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _blake3(file_path: Union[str, Path]) -> str:
|
def _blake3(file_path: Path) -> str:
|
||||||
"""Hashes a file using BLAKE3"""
|
"""Hashes a file using BLAKE3"""
|
||||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||||
file_hasher.update_mmap(file_path)
|
file_hasher.update_mmap(file_path)
|
||||||
return file_hasher.hexdigest()
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sha1_fast(file_path: Union[str, Path]) -> str:
|
def _sha1_fast(file_path: Path) -> str:
|
||||||
"""Hashes a file using SHA1, but with a block size of 2**16. The result is not a standard SHA1 hash due to the
|
"""Hashes a file using SHA1, but with a block size of 2**16.
|
||||||
# padding introduced by the block size. The algorithm is, however, very fast."""
|
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
|
||||||
|
The algorithm is, however, very fast."""
|
||||||
BLOCK_SIZE = 2**16
|
BLOCK_SIZE = 2**16
|
||||||
file_hash = hashlib.sha1()
|
file_hash = hashlib.sha1()
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
@ -123,7 +128,7 @@ class ModelHash:
|
|||||||
return file_hash.hexdigest()
|
return file_hash.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _hashlib(file_path: Union[str, Path], algorithm: ALGORITHMS) -> str:
|
def _hashlib(file_path: Path, algorithm: ALGORITHMS) -> str:
|
||||||
"""Hashes a file using a hashlib algorithm"""
|
"""Hashes a file using a hashlib algorithm"""
|
||||||
file_hasher = hashlib.new(algorithm)
|
file_hasher = hashlib.new(algorithm)
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user