mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
9 Commits
next-fix-t
...
psyche/mm/
Author | SHA1 | Date | |
---|---|---|---|
bd222454cd | |||
ee78412aaa | |||
5b675d8481 | |||
3493ae7cdb | |||
4ee52ed689 | |||
9b4f1126c0 | |||
7b2f81babb | |||
2b1cb569eb | |||
98a13aa7dc |
@ -7,7 +7,6 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -28,7 +27,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
@ -167,7 +166,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
new_hash = ModelHash().hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
@ -469,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
new_hash = ModelHash().hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
@ -536,16 +535,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
key = self._create_key()
|
||||
if config and not config.get("key", None):
|
||||
config["key"] = key
|
||||
# the model key is either the forced key specified in config,
|
||||
# or it is the file/directory hash computed by probe
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
override_key: Optional[str] = config.get("key") if config else None
|
||||
|
||||
assert info.original_hash # always assigned by probe()
|
||||
info.key = override_key or info.original_hash
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
400
invokeai/backend/model_manager/config_new.py
Normal file
400
invokeai/backend/model_manager/config_new.py
Normal file
@ -0,0 +1,400 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Configuration definitions for image generation models.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigFactory
|
||||
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||
name='foo',
|
||||
base='sd-1',
|
||||
type='main',
|
||||
config='configs/stable-diffusion/v1-inference.yaml',
|
||||
variant='normal',
|
||||
format='checkpoint'
|
||||
)
|
||||
config = ModelConfigFactory.make_config(raw)
|
||||
print(config.name)
|
||||
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, Discriminator, Field, JsonValue, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""The source of the model."""
|
||||
|
||||
HF_REPO_ID = "hf_repo_id"
|
||||
CIVITAI = "civitai"
|
||||
URL = "url"
|
||||
PATH = "path"
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Model type."""
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
"""Scheduler prediction type."""
|
||||
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
|
||||
|
||||
class _ModelConfigBase(BaseModel):
|
||||
"""The configuration of a model."""
|
||||
|
||||
id: str = Field(description="The unique identifier of the model") # Primary Key
|
||||
hash: str = Field(description="The BLAKE3 hash of the model.", frozen=True)
|
||||
base: BaseModelType = Field(description="The base of the model")
|
||||
path: str = Field(description="The path of the model")
|
||||
name: str = Field(description="The name of the model")
|
||||
description: Optional[str] = Field(description="The description of the model", default=None)
|
||||
|
||||
def compute_hash(self, algorithm: ALGORITHM = "blake3") -> str:
|
||||
"""Compute the hash of the model."""
|
||||
return ModelHash(algorithm).hash(self.path)
|
||||
|
||||
|
||||
class _CheckpointConfig(_ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config_path: str = Field(description="Path to the checkpoint model config file")
|
||||
|
||||
|
||||
class _DiffusersConfig(_ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
|
||||
|
||||
class LoRALycorisConfig(_ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(_ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(_ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
|
||||
|
||||
|
||||
class VaeDiffusersConfig(_ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
|
||||
|
||||
|
||||
class TextualInversionFileConfig(_ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
|
||||
|
||||
|
||||
class TextualInversionFolderConfig(_ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
|
||||
|
||||
|
||||
class _MainConfig(_ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class IPAdapterConfig(_ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(_ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
class T2IAdapterConfig(_ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
if isinstance(v, dict):
|
||||
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
|
||||
return f"{v.getattr('type')}.{v.getattr('format')}"
|
||||
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
|
||||
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class ModelRecord(BaseModel):
|
||||
"""A record of a model in the database."""
|
||||
|
||||
# Internal DB/record data
|
||||
id: str = Field(description="The unique identifier of the model") # Primary Key
|
||||
config: AnyModelConfig = Field(description="The configuration of the model")
|
||||
source: str = Field(
|
||||
description="The original source of the model (path, URL or repo_id)",
|
||||
frozen=True, # This field is immutable
|
||||
)
|
||||
source_type: ModelSourceType = Field(
|
||||
description="The type of the source of the model",
|
||||
frozen=True, # This field is immutable
|
||||
)
|
||||
source_api_response: Optional[JsonValue] = Field(
|
||||
description="The original API response from which the model was installed.",
|
||||
default=None,
|
||||
frozen=True, # This field is immutable
|
||||
)
|
||||
created_at: datetime | str = Field(description="When the model was created")
|
||||
updated_at: datetime | str = Field(description="When the model was last updated")
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type[_ModelConfigBase]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
||||
:param model_data: A raw dict corresponding the obect fields to be
|
||||
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||
object, which will be passed through unchanged.
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
model: Optional[_ModelConfigBase] = None
|
||||
if isinstance(model_data, _ModelConfigBase):
|
||||
model = model_data
|
||||
elif dest_class:
|
||||
model = dest_class.model_validate(model_data)
|
||||
else:
|
||||
# mypy doesn't typecheck TypeAdapters well?
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
return model # type: ignore
|
@ -7,60 +7,138 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||
'a8e693a126ea5b831c96064dc569956f'
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Callable, Literal, Union
|
||||
|
||||
from imohash import hashfile
|
||||
from blake3 import blake3
|
||||
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha1_fast",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
class ModelHash:
|
||||
"""
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
|
||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||
|
||||
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||
|
||||
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||
that directory hashes are never weaker than the file hashes.
|
||||
|
||||
Usage
|
||||
|
||||
```py
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm == "sha1_fast":
|
||||
self._hash_file = self._sha1_fast
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
self._hash_file = self._get_hashlib(algorithm)
|
||||
else:
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
def hash(self, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
|
||||
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
:param model_location: Path to the model
|
||||
"""
|
||||
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
return self._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
return self._hash_dir(model_location)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
def _hash_dir(self, model_location: Path) -> str:
|
||||
"""Compute the hash for all files in a directory and return a hexdigest."""
|
||||
model_component_paths = self._get_file_paths(model_location)
|
||||
|
||||
:param model_location: Path to the model file
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
component_hashes: list[str] = []
|
||||
for component in sorted(model_component_paths):
|
||||
component_hashes.append(self._hash_file(component))
|
||||
|
||||
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||
# for the composite hash
|
||||
composite_hasher = blake3()
|
||||
for h in component_hashes:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
def _get_file_paths(cls, dir: Path) -> list[Path]:
|
||||
"""Return a list of all model files in the directory."""
|
||||
files: list[Path] = []
|
||||
for root, _dirs, _files in os.walk(dir):
|
||||
for file in _files:
|
||||
if file.endswith(MODEL_FILE_EXTENSIONS):
|
||||
files.append(Path(root, file))
|
||||
return files
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
@staticmethod
|
||||
def _blake3(file_path: Path) -> str:
|
||||
"""Hashes a file using BLAKE3"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(file_path)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
@staticmethod
|
||||
def _sha1_fast(file_path: Path) -> str:
|
||||
"""Hashes a file using SHA1, but with a block size of 2**16.
|
||||
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
|
||||
The algorithm is, however, very fast."""
|
||||
BLOCK_SIZE = 2**16
|
||||
file_hash = hashlib.sha1()
|
||||
with open(file_path, "rb") as f:
|
||||
data = f.read(BLOCK_SIZE)
|
||||
file_hash.update(data)
|
||||
return file_hash.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||
"""Hashes a file using a hashlib algorithm"""
|
||||
|
||||
def hasher(file_path: Path) -> str:
|
||||
file_hasher = hashlib.new(algorithm)
|
||||
with open(file_path, "rb") as f:
|
||||
file_hasher.update(f.read())
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
return hasher
|
||||
|
@ -160,7 +160,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
|
@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
@ -142,7 +142,10 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
|
@ -21,7 +21,7 @@ from .config import (
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -147,7 +147,7 @@ class ModelProbe(object):
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
|
@ -64,6 +64,7 @@ dependencies = [
|
||||
|
||||
# Auxiliary dependencies, pinned only if necessary.
|
||||
"albumentations",
|
||||
"blake3",
|
||||
"click",
|
||||
"datasets",
|
||||
"Deprecated",
|
||||
@ -72,7 +73,6 @@ dependencies = [
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"facexlib",
|
||||
"imohash",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"npyscreen",
|
||||
"omegaconf",
|
||||
|
@ -31,7 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
|
||||
assert len(matches) == 0
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
assert key is not None
|
||||
assert len(key) == 32
|
||||
assert len(key) == 40 # length of the sha1 hash
|
||||
|
||||
|
||||
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
|
File diff suppressed because one or more lines are too long
@ -133,7 +133,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
80
tests/test_model_hash.py
Normal file
80
tests/test_model_hash.py
Normal file
@ -0,0 +1,80 @@
|
||||
# pyright:reportPrivateUsage=false
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import pytest
|
||||
from blake3 import blake3
|
||||
|
||||
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
||||
|
||||
test_cases: list[tuple[ALGORITHM, str]] = [
|
||||
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
||||
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
||||
(
|
||||
"sha512",
|
||||
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
|
||||
),
|
||||
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
|
||||
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
|
||||
file = Path(tmp_path / "test")
|
||||
file.write_text("model data")
|
||||
md5 = ModelHash(algorithm).hash(file)
|
||||
assert md5 == expected_hash
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
||||
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
|
||||
model_hash = ModelHash(algorithm)
|
||||
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
||||
|
||||
for f in files:
|
||||
f.write_text("data")
|
||||
|
||||
md5 = model_hash.hash(tmp_path)
|
||||
|
||||
# Manual implementation of composite hash - always uses BLAKE3
|
||||
composite_hasher = blake3()
|
||||
for f in files:
|
||||
h = model_hash.hash(f)
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
|
||||
assert md5 == composite_hasher.hexdigest()
|
||||
|
||||
|
||||
def test_model_hash_raises_error_on_invalid_algorithm():
|
||||
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
|
||||
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
|
||||
|
||||
|
||||
def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
|
||||
return {str(p) for p in paths}
|
||||
|
||||
|
||||
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
||||
model_files = {
|
||||
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
|
||||
}
|
||||
|
||||
for i, f in enumerate(model_files):
|
||||
f.write_text(f"data{i}")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||
|
||||
# Add file that should be ignored - hash should not change
|
||||
file = tmp_path / "test.icecream"
|
||||
file.write_text("data")
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||
|
||||
# Add file that should not be ignored - hash should change
|
||||
file = tmp_path / "test.bin"
|
||||
file.write_text("more data")
|
||||
model_files.add(file)
|
||||
|
||||
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
Reference in New Issue
Block a user