Compare commits

...

9 Commits

Author SHA1 Message Date
bd222454cd feat(mm): draft revised config 2024-02-28 21:48:46 +11:00
ee78412aaa tests(mm): add tests for ModelHash 2024-02-28 18:08:35 +11:00
5b675d8481 feat(mm): make ModelHash instantiatable, taking an algorithm as arg 2024-02-28 18:08:35 +11:00
3493ae7cdb feat(mm): modularize ModelHash to facilitate testing 2024-02-28 18:08:34 +11:00
4ee52ed689 feat(mm): add hashing algos to ModelHash
- Some algos are slow, so it is now just called ModelHash
- Added all hashlib algos, plus BLAKE3 and the fast (but incorrect) SHA1 algo
2024-02-28 18:08:34 +11:00
9b4f1126c0 feat(mm): make hash.py a script for testing 2024-02-28 18:08:34 +11:00
7b2f81babb feat(mm): use blake3 for hashing 2024-02-28 18:08:34 +11:00
2b1cb569eb make model key assignment deterministic
- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
2024-02-28 18:08:34 +11:00
98a13aa7dc handle change to Civitai metadata schema for commercial usage 2024-02-28 16:15:29 +11:00
12 changed files with 633 additions and 61 deletions

View File

@ -7,7 +7,6 @@ import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@ -28,7 +27,7 @@ from invokeai.backend.model_manager.config import (
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.hash import ModelHash
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
@ -167,7 +166,7 @@ class ModelInstallService(ModelInstallServiceBase):
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
new_hash = FastModelHash.hash(new_path)
new_hash = ModelHash().hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
@ -469,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
new_hash = ModelHash().hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
@ -536,16 +535,16 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
key = self._create_key()
if config and not config.get("key", None):
config["key"] = key
# the model key is either the forced key specified in config,
# or it is the file/directory hash computed by probe
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):

View File

@ -3,7 +3,6 @@
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = FastModelHash.hash(self.config.models_path / stanza.path)
hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -0,0 +1,400 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
from datetime import datetime
from enum import Enum
from typing import Literal, Optional, Type, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, Discriminator, Field, JsonValue, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
from invokeai.backend.raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class ModelSourceType(str, Enum):
"""The source of the model."""
HF_REPO_ID = "hf_repo_id"
CIVITAI = "civitai"
URL = "url"
PATH = "path"
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OPENVINO = "openvino"
FLAX = "flax"
class _ModelConfigBase(BaseModel):
"""The configuration of a model."""
id: str = Field(description="The unique identifier of the model") # Primary Key
hash: str = Field(description="The BLAKE3 hash of the model.", frozen=True)
base: BaseModelType = Field(description="The base of the model")
path: str = Field(description="The path of the model")
name: str = Field(description="The name of the model")
description: Optional[str] = Field(description="The description of the model", default=None)
def compute_hash(self, algorithm: ALGORITHM = "blake3") -> str:
"""Compute the hash of the model."""
return ModelHash(algorithm).hash(self.path)
class _CheckpointConfig(_ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config_path: str = Field(description="Path to the checkpoint model config file")
class _DiffusersConfig(_ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRALycorisConfig(_ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
class LoRADiffusersConfig(_ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
class VaeCheckpointConfig(_ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
class VaeDiffusersConfig(_ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
class TextualInversionFileConfig(_ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
class TextualInversionFolderConfig(_ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
class _MainConfig(_ModelConfigBase):
"""Model config for main models."""
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
class IPAdapterConfig(_ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
class CLIPVisionDiffusersConfig(_ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
class T2IAdapterConfig(_ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
if isinstance(v, dict):
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
return f"{v.getattr('type')}.{v.getattr('format')}"
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelRecord(BaseModel):
"""A record of a model in the database."""
# Internal DB/record data
id: str = Field(description="The unique identifier of the model") # Primary Key
config: AnyModelConfig = Field(description="The configuration of the model")
source: str = Field(
description="The original source of the model (path, URL or repo_id)",
frozen=True, # This field is immutable
)
source_type: ModelSourceType = Field(
description="The type of the source of the model",
frozen=True, # This field is immutable
)
source_api_response: Optional[JsonValue] = Field(
description="The original API response from which the model was installed.",
default=None,
frozen=True, # This field is immutable
)
created_at: datetime | str = Field(description="When the model was created")
updated_at: datetime | str = Field(description="When the model was last updated")
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[_ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[_ModelConfigBase] = None
if isinstance(model_data, _ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model # type: ignore

View File

@ -7,60 +7,138 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from typing import Callable, Literal, Union
from imohash import hashfile
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha1_fast",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
class ModelHash:
"""
Creates a hash of a model using a specified algorithm.
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage
```py
ModelHash().hash("path/to/some/model.safetensors")
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "sha1_fast":
self._hash_file = self._sha1_fast
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
else:
raise ValueError(f"Algorithm {algorithm} not available")
def hash(self, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
return self._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
return self._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
def _hash_dir(self, model_location: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest."""
model_component_paths = self._get_file_paths(model_location)
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
def _get_file_paths(cls, dir: Path) -> list[Path]:
"""Return a list of all model files in the directory."""
files: list[Path] = []
for root, _dirs, _files in os.walk(dir):
for file in _files:
if file.endswith(MODEL_FILE_EXTENSIONS):
files.append(Path(root, file))
return files
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
@staticmethod
def _sha1_fast(file_path: Path) -> str:
"""Hashes a file using SHA1, but with a block size of 2**16.
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
The algorithm is, however, very fast."""
BLOCK_SIZE = 2**16
file_hash = hashlib.sha1()
with open(file_path, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Hashes a file using a hashlib algorithm"""
def hasher(file_path: Path) -> str:
file_hasher = hashlib.new(algorithm)
with open(file_path, "rb") as f:
file_hasher.update(f.read())
return file_hasher.hexdigest()
return hasher

View File

@ -160,7 +160,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
nsfw=model_json["nsfw"],
restrictions=LicenseRestrictions(
AllowNoCredit=model_json["allowNoCredit"],
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
AllowDerivatives=model_json["allowDerivatives"],
AllowDifferentLicense=model_json["allowDifferentLicense"],
),

View File

@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
AllowDifferentLicense: bool = Field(
description="if true, derivatives of this model be redistributed under a different license", default=False
)
AllowCommercialUse: Optional[CommercialUsage] = Field(
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
description="Type of commercial use allowed if no commercial use is allowed.", default=None
)
@ -142,7 +142,10 @@ class CivitaiMetadata(ModelMetadataWithFiles):
if self.restrictions.AllowCommercialUse is None:
return False
else:
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
# accommodate schema change
acu = self.restrictions.AllowCommercialUse
commercial_usage = acu if isinstance(acu, set) else {acu}
return CommercialUsage.No not in commercial_usage
@property
def allow_derivatives(self) -> bool:

View File

@ -21,7 +21,7 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -147,7 +147,7 @@ class ModelProbe(object):
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
hash = ModelHash().hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()

View File

@ -64,6 +64,7 @@ dependencies = [
# Auxiliary dependencies, pinned only if necessary.
"albumentations",
"blake3",
"click",
"datasets",
"Deprecated",
@ -72,7 +73,6 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"imohash",
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",

View File

@ -31,7 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
assert key is not None
assert len(key) == 32
assert len(key) == 40 # length of the sha1 hash
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:

File diff suppressed because one or more lines are too long

View File

@ -133,7 +133,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}

80
tests/test_model_hash.py Normal file
View File

@ -0,0 +1,80 @@
# pyright:reportPrivateUsage=false
from pathlib import Path
from typing import Iterable
import pytest
from blake3 import blake3
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
test_cases: list[tuple[ALGORITHM, str]] = [
("md5", "a0cd925fc063f98dbf029eee315060c3"),
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
(
"sha512",
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
),
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
]
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
file = Path(tmp_path / "test")
file.write_text("model data")
md5 = ModelHash(algorithm).hash(file)
assert md5 == expected_hash
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
model_hash = ModelHash(algorithm)
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
for f in files:
f.write_text("data")
md5 = model_hash.hash(tmp_path)
# Manual implementation of composite hash - always uses BLAKE3
composite_hasher = blake3()
for f in files:
h = model_hash.hash(f)
composite_hasher.update(h.encode("utf-8"))
assert md5 == composite_hasher.hexdigest()
def test_model_hash_raises_error_on_invalid_algorithm():
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
return {str(p) for p in paths}
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
model_files = {
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
}
for i, f in enumerate(model_files):
f.write_text(f"data{i}")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
# Add file that should be ignored - hash should not change
file = tmp_path / "test.icecream"
file.write_text("data")
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
# Add file that should not be ignored - hash should change
file = tmp_path / "test.bin"
file.write_text("more data")
model_files.add(file)
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)