mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): faster hashing for spinning disk HDDs
BLAKE3 has poor performance on spinning disks when parallelized. See https://github.com/BLAKE3-team/BLAKE3/issues/31 - Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted. - Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality. - Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixes the aforementioned performance issue - Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3` - Update all calls of the probe to use the app's configured hashing algorithm - Update an external script that probes models - Update tests - Move ModelHash into its own module to avoid circuclar import issues
This commit is contained in:
parent
8287fcf097
commit
eb6e6548ed
@ -179,6 +179,8 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from pydantic.config import JsonDict
|
from pydantic.config import JsonDict
|
||||||
from pydantic_settings import SettingsConfigDict
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
|
|
||||||
from .config_base import InvokeAISettings
|
from .config_base import InvokeAISettings
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
@ -360,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# MODEL INSTALL
|
# MODEL INSTALL
|
||||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall)
|
hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'none' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
|
||||||
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
|
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
|
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
|
||||||
|
@ -22,7 +22,6 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -154,10 +153,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
if self._app_config.skip_model_hash:
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
config["hash"] = uuid_string()
|
|
||||||
|
|
||||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
|
||||||
|
|
||||||
if preferred_name := config.get("name"):
|
if preferred_name := config.get("name"):
|
||||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||||
@ -585,10 +581,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
) -> str:
|
) -> str:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
if self._app_config.skip_model_hash:
|
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
config["hash"] = uuid_string()
|
|
||||||
|
|
||||||
info = info or ModelProbe.probe(model_path, config)
|
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
|
@ -1,12 +1,4 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
"""
|
|
||||||
Fast hashing of diffusers and checkpoint-style models.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from invokeai.backend.model_managre.model_hash import FastModelHash
|
|
||||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
|
||||||
'a8e693a126ea5b831c96064dc569956f'
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
@ -15,9 +7,9 @@ from typing import Callable, Literal, Optional, Union
|
|||||||
|
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
ALGORITHM = Literal[
|
HASHING_ALGORITHMS = Literal[
|
||||||
"md5",
|
"md5",
|
||||||
"sha1",
|
"sha1",
|
||||||
"sha224",
|
"sha224",
|
||||||
@ -33,7 +25,10 @@ ALGORITHM = Literal[
|
|||||||
"shake_128",
|
"shake_128",
|
||||||
"shake_256",
|
"shake_256",
|
||||||
"blake3",
|
"blake3",
|
||||||
|
"blake3_single",
|
||||||
|
"random",
|
||||||
]
|
]
|
||||||
|
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||||
|
|
||||||
|
|
||||||
class ModelHash:
|
class ModelHash:
|
||||||
@ -53,6 +48,8 @@ class ModelHash:
|
|||||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||||
that directory hashes are never weaker than the file hashes.
|
that directory hashes are never weaker than the file hashes.
|
||||||
|
|
||||||
|
A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
```py
|
```py
|
||||||
# BLAKE3 hash
|
# BLAKE3 hash
|
||||||
@ -62,11 +59,17 @@ class ModelHash:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
def __init__(
|
||||||
|
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
|
||||||
|
) -> None:
|
||||||
if algorithm == "blake3":
|
if algorithm == "blake3":
|
||||||
self._hash_file = self._blake3
|
self._hash_file = self._blake3
|
||||||
|
elif algorithm == "blake3_single":
|
||||||
|
self._hash_file = self._blake3_single
|
||||||
elif algorithm in hashlib.algorithms_available:
|
elif algorithm in hashlib.algorithms_available:
|
||||||
self._hash_file = self._get_hashlib(algorithm)
|
self._hash_file = self._get_hashlib(algorithm)
|
||||||
|
elif algorithm == "random":
|
||||||
|
self._hash_file = self._random
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Algorithm {algorithm} not available")
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
@ -137,7 +140,7 @@ class ModelHash:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _blake3(file_path: Path) -> str:
|
def _blake3(file_path: Path) -> str:
|
||||||
"""Hashes a file using BLAKE3
|
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the file to hash
|
file_path: Path to the file to hash
|
||||||
@ -150,7 +153,21 @@ class ModelHash:
|
|||||||
return file_hasher.hexdigest()
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
def _blake3_single(file_path: Path) -> str:
|
||||||
|
"""Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hexdigest of the hash of the file
|
||||||
|
"""
|
||||||
|
file_hasher = blake3()
|
||||||
|
file_hasher.update_mmap(file_path)
|
||||||
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
|
||||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -172,6 +189,13 @@ class ModelHash:
|
|||||||
|
|
||||||
return hashlib_hasher
|
return hashlib_hasher
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _random(_file_path: Path) -> str:
|
||||||
|
"""Returns a random string. This is not a hash.
|
||||||
|
|
||||||
|
The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
|
||||||
|
return blake3(uuid_string().encode()).hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _default_file_filter(file_path: str) -> bool:
|
def _default_file_filter(file_path: str) -> bool:
|
||||||
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||||
from invokeai.backend.util.util import SilenceWarnings
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@ -24,7 +25,6 @@ from .config import (
|
|||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from .hash import ModelHash
|
|
||||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||||
|
|
||||||
CkptType = Dict[str, Any]
|
CkptType = Dict[str, Any]
|
||||||
@ -113,9 +113,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe(
|
def probe(
|
||||||
cls,
|
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
|
||||||
model_path: Path,
|
|
||||||
fields: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Probe the model at model_path and return its configuration record.
|
Probe the model at model_path and return its configuration record.
|
||||||
@ -160,7 +158,7 @@ class ModelProbe(object):
|
|||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
|
||||||
fields["default_settings"] = (
|
fields["default_settings"] = (
|
||||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
||||||
|
@ -4,20 +4,30 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import get_args
|
||||||
|
|
||||||
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
|
from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
|
||||||
|
|
||||||
|
algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Probe model type")
|
parser = argparse.ArgumentParser(description="Probe model type")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"model_path",
|
"model_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hash_algo",
|
||||||
|
type=str,
|
||||||
|
default="blake3",
|
||||||
|
help=f"Hashing algorithm to use (default: blake3), one of: {algos}",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
for path in args.model_path:
|
for path in args.model_path:
|
||||||
try:
|
try:
|
||||||
info = ModelProbe.probe(path)
|
info = ModelProbe.probe(path, hash_algo=args.hash_algo)
|
||||||
print(f"{path}:{info.model_dump_json(indent=4)}")
|
print(f"{path}:{info.model_dump_json(indent=4)}")
|
||||||
except InvalidModelConfigException as exc:
|
except InvalidModelConfigException as exc:
|
||||||
print(exc)
|
print(exc)
|
||||||
|
@ -6,9 +6,9 @@ from typing import Iterable
|
|||||||
import pytest
|
import pytest
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
from invokeai.backend.model_manager.hash import ALGORITHM, MODEL_FILE_EXTENSIONS, ModelHash
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, MODEL_FILE_EXTENSIONS, ModelHash
|
||||||
|
|
||||||
test_cases: list[tuple[ALGORITHM, str]] = [
|
test_cases: list[tuple[HASHING_ALGORITHMS, str]] = [
|
||||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||||
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
||||||
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
||||||
@ -21,7 +21,7 @@ test_cases: list[tuple[ALGORITHM, str]] = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
|
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
|
||||||
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
|
def test_model_hash_hashes_file(tmp_path: Path, algorithm: HASHING_ALGORITHMS, expected_hash: str):
|
||||||
file = Path(tmp_path / "test")
|
file = Path(tmp_path / "test")
|
||||||
file.write_text("model data")
|
file.write_text("model data")
|
||||||
md5 = ModelHash(algorithm).hash(file)
|
md5 = ModelHash(algorithm).hash(file)
|
||||||
@ -29,7 +29,7 @@ def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_h
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
||||||
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
|
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: HASHING_ALGORITHMS):
|
||||||
model_hash = ModelHash(algorithm)
|
model_hash = ModelHash(algorithm)
|
||||||
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
||||||
|
|
||||||
@ -47,6 +47,24 @@ def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
|
|||||||
assert md5 == composite_hasher.hexdigest()
|
assert md5 == composite_hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_hash_blake3_matches_blake3_single(tmp_path: Path):
|
||||||
|
model_hash = ModelHash("blake3")
|
||||||
|
model_hash_simple = ModelHash("blake3_single")
|
||||||
|
|
||||||
|
file = tmp_path / "test.bin"
|
||||||
|
file.write_text("model data")
|
||||||
|
|
||||||
|
assert model_hash.hash(file) == model_hash_simple.hash(file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_hash_random_algorithm(tmp_path: Path):
|
||||||
|
model_hash = ModelHash("random")
|
||||||
|
file = tmp_path / "test.bin"
|
||||||
|
file.write_text("model data")
|
||||||
|
|
||||||
|
assert model_hash.hash(file) != model_hash.hash(file)
|
||||||
|
|
||||||
|
|
||||||
def test_model_hash_raises_error_on_invalid_algorithm():
|
def test_model_hash_raises_error_on_invalid_algorithm():
|
||||||
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
|
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
|
||||||
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
|
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
|
||||||
|
Loading…
Reference in New Issue
Block a user