From eb6e6548ed3e18aa44b3cbcbc29fc286ccc8c7f6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 14 Mar 2024 09:44:55 +1100 Subject: [PATCH] feat(mm): faster hashing for spinning disk HDDs BLAKE3 has poor performance on spinning disks when parallelized. See https://github.com/BLAKE3-team/BLAKE3/issues/31 - Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted. - Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality. - Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixes the aforementioned performance issue - Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3` - Update all calls of the probe to use the app's configured hashing algorithm - Update an external script that probes models - Update tests - Move ModelHash into its own module to avoid circuclar import issues --- .../app/services/config/config_default.py | 4 +- .../model_install/model_install_default.py | 11 +--- .../hash.py => model_hash/model_hash.py} | 50 ++++++++++++++----- invokeai/backend/model_manager/probe.py | 8 ++- scripts/probe-model.py | 12 ++++- tests/test_model_hash.py | 26 ++++++++-- 6 files changed, 78 insertions(+), 33 deletions(-) rename invokeai/backend/{model_manager/hash.py => model_hash/model_hash.py} (79%) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index fb09b49bbc..85db8a41ab 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -179,6 +179,8 @@ from pydantic import BaseModel, Field, field_validator from pydantic.config import JsonDict from pydantic_settings import SettingsConfigDict +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS + from .config_base import InvokeAISettings INIT_FILE = Path("invokeai.yaml") @@ -360,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings): node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes) # MODEL INSTALL - skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall) + hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'none' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall) remote_api_tokens : Optional[list[URLRegexToken]] = Field( default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 138bde8bbf..273be6ba4b 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -22,7 +22,6 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.app.util.misc import uuid_string from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -154,10 +153,7 @@ class ModelInstallService(ModelInstallServiceBase): model_path = Path(model_path) config = config or {} - if self._app_config.skip_model_hash: - config["hash"] = uuid_string() - - info: AnyModelConfig = ModelProbe.probe(Path(model_path), config) + info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm) if preferred_name := config.get("name"): preferred_name = Path(preferred_name).with_suffix(model_path.suffix) @@ -585,10 +581,7 @@ class ModelInstallService(ModelInstallServiceBase): ) -> str: config = config or {} - if self._app_config.skip_model_hash: - config["hash"] = uuid_string() - - info = info or ModelProbe.probe(model_path, config) + info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm) model_path = model_path.resolve() diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_hash/model_hash.py similarity index 79% rename from invokeai/backend/model_manager/hash.py rename to invokeai/backend/model_hash/model_hash.py index 656b591f4a..b1cd93be7b 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_hash/model_hash.py @@ -1,12 +1,4 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team -""" -Fast hashing of diffusers and checkpoint-style models. - -Usage: -from invokeai.backend.model_managre.model_hash import FastModelHash ->>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') -'a8e693a126ea5b831c96064dc569956f' -""" import hashlib import os @@ -15,9 +7,9 @@ from typing import Callable, Literal, Optional, Union from blake3 import blake3 -MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") +from invokeai.app.util.misc import uuid_string -ALGORITHM = Literal[ +HASHING_ALGORITHMS = Literal[ "md5", "sha1", "sha224", @@ -33,7 +25,10 @@ ALGORITHM = Literal[ "shake_128", "shake_256", "blake3", + "blake3_single", + "random", ] +MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth") class ModelHash: @@ -53,6 +48,8 @@ class ModelHash: The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring that directory hashes are never weaker than the file hashes. + A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash. + Usage: ```py # BLAKE3 hash @@ -62,11 +59,17 @@ class ModelHash: ``` """ - def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None: + def __init__( + self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None + ) -> None: if algorithm == "blake3": self._hash_file = self._blake3 + elif algorithm == "blake3_single": + self._hash_file = self._blake3_single elif algorithm in hashlib.algorithms_available: self._hash_file = self._get_hashlib(algorithm) + elif algorithm == "random": + self._hash_file = self._random else: raise ValueError(f"Algorithm {algorithm} not available") @@ -137,7 +140,7 @@ class ModelHash: @staticmethod def _blake3(file_path: Path) -> str: - """Hashes a file using BLAKE3 + """Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory. Args: file_path: Path to the file to hash @@ -150,7 +153,21 @@ class ModelHash: return file_hasher.hexdigest() @staticmethod - def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]: + def _blake3_single(file_path: Path) -> str: + """Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives. + + Args: + file_path: Path to the file to hash + + Returns: + Hexdigest of the hash of the file + """ + file_hasher = blake3() + file_hasher.update_mmap(file_path) + return file_hasher.hexdigest() + + @staticmethod + def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]: """Factory function that returns a function to hash a file with the given algorithm. Args: @@ -172,6 +189,13 @@ class ModelHash: return hashlib_hasher + @staticmethod + def _random(_file_path: Path) -> str: + """Returns a random string. This is not a hash. + + The string is a UUID, hashed with BLAKE3 to ensure that it is unique.""" + return blake3(uuid_string().encode()).hexdigest() + @staticmethod def _default_file_filter(file_path: str) -> bool: """A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a9827fb564..dcec5a9a25 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path import invokeai.backend.util.logging as logger from invokeai.app.util.misc import uuid_string +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash from invokeai.backend.util.util import SilenceWarnings from .config import ( @@ -24,7 +25,6 @@ from .config import ( ModelVariantType, SchedulerPredictionType, ) -from .hash import ModelHash from .util.model_util import lora_token_vector_length, read_checkpoint_meta CkptType = Dict[str, Any] @@ -113,9 +113,7 @@ class ModelProbe(object): @classmethod def probe( - cls, - model_path: Path, - fields: Optional[Dict[str, Any]] = None, + cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3" ) -> AnyModelConfig: """ Probe the model at model_path and return its configuration record. @@ -160,7 +158,7 @@ class ModelProbe(object): fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" ) fields["format"] = fields.get("format") or probe.get_format() - fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) + fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) fields["default_settings"] = ( fields.get("default_settings") or probe.get_default_settings(fields["name"]) diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 8518b76437..eca0f4c415 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -4,20 +4,30 @@ import argparse from pathlib import Path +from typing import get_args +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe +algos = ", ".join(set(get_args(HASHING_ALGORITHMS))) + parser = argparse.ArgumentParser(description="Probe model type") parser.add_argument( "model_path", type=Path, nargs="+", ) +parser.add_argument( + "--hash_algo", + type=str, + default="blake3", + help=f"Hashing algorithm to use (default: blake3), one of: {algos}", +) args = parser.parse_args() for path in args.model_path: try: - info = ModelProbe.probe(path) + info = ModelProbe.probe(path, hash_algo=args.hash_algo) print(f"{path}:{info.model_dump_json(indent=4)}") except InvalidModelConfigException as exc: print(exc) diff --git a/tests/test_model_hash.py b/tests/test_model_hash.py index 641a150034..e7150633e5 100644 --- a/tests/test_model_hash.py +++ b/tests/test_model_hash.py @@ -6,9 +6,9 @@ from typing import Iterable import pytest from blake3 import blake3 -from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash +from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash -test_cases: list[tuple[ALGORITHM, str]] = [ +test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [ ("md5", "a0cd925fc063f98dbf029eee315060c3"), ("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"), ("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"), @@ -21,7 +21,7 @@ test_cases: list[tuple[ALGORITHM, str]] = [ @pytest.mark.parametrize("algorithm,expected_hash", test_cases) -def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str): +def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str): file = Path(tmp_path / "test") file.write_text("model data") md5 = ModelHash(algorithm).hash(file) @@ -29,7 +29,7 @@ def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_h @pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"]) -def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM): +def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS): model_hash = ModelHash(algorithm) files = [Path(tmp_path, f"{i}.bin") for i in range(5)] @@ -47,6 +47,24 @@ def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM): assert md5 == composite_hasher.hexdigest() +def test_model_hash_blake3_matches_blake3_single(tmp_path: Path): + model_hash = ModelHash("blake3") + model_hash_simple = ModelHash("blake3_single") + + file = tmp_path / "test.bin" + file.write_text("model data") + + assert model_hash.hash(file) == model_hash_simple.hash(file) + + +def test_model_hash_random_algorithm(tmp_path: Path): + model_hash = ModelHash("random") + file = tmp_path / "test.bin" + file.write_text("model data") + + assert model_hash.hash(file) != model_hash.hash(file) + + def test_model_hash_raises_error_on_invalid_algorithm(): with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"): ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]