mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
9 Commits
next-fix-t
...
psyche/mm/
Author | SHA1 | Date | |
---|---|---|---|
bd222454cd | |||
ee78412aaa | |||
5b675d8481 | |||
3493ae7cdb | |||
4ee52ed689 | |||
9b4f1126c0 | |||
7b2f81babb | |||
2b1cb569eb | |||
98a13aa7dc |
@ -7,7 +7,6 @@ import time
|
|||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from random import randbytes
|
|
||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@ -28,7 +27,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.hash import FastModelHash
|
from invokeai.backend.model_manager.hash import ModelHash
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
CivitaiMetadataFetch,
|
CivitaiMetadataFetch,
|
||||||
@ -167,7 +166,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
raise DuplicateModelException(
|
raise DuplicateModelException(
|
||||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||||
) from excp
|
) from excp
|
||||||
new_hash = FastModelHash.hash(new_path)
|
new_hash = ModelHash().hash(new_path)
|
||||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||||
|
|
||||||
return self._register(
|
return self._register(
|
||||||
@ -469,7 +468,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
new_path = self._move_model(old_path, new_path)
|
new_path = self._move_model(old_path, new_path)
|
||||||
new_hash = FastModelHash.hash(new_path)
|
new_hash = ModelHash().hash(new_path)
|
||||||
model.path = new_path.relative_to(models_dir).as_posix()
|
model.path = new_path.relative_to(models_dir).as_posix()
|
||||||
if model.current_hash != new_hash:
|
if model.current_hash != new_hash:
|
||||||
assert (
|
assert (
|
||||||
@ -536,16 +535,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
setattr(info, key, value)
|
setattr(info, key, value)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def _create_key(self) -> str:
|
|
||||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
key = self._create_key()
|
# the model key is either the forced key specified in config,
|
||||||
if config and not config.get("key", None):
|
# or it is the file/directory hash computed by probe
|
||||||
config["key"] = key
|
|
||||||
info = info or ModelProbe.probe(model_path, config)
|
info = info or ModelProbe.probe(model_path, config)
|
||||||
|
override_key: Optional[str] = config.get("key") if config else None
|
||||||
|
|
||||||
|
assert info.original_hash # always assigned by probe()
|
||||||
|
info.key = override_key or info.original_hash
|
||||||
|
|
||||||
model_path = model_path.absolute()
|
model_path = model_path.absolute()
|
||||||
if model_path.is_relative_to(self.app_config.models_path):
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from hashlib import sha1
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.hash import FastModelHash
|
from invokeai.backend.model_manager.hash import ModelHash
|
||||||
|
|
||||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
|||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
try:
|
try:
|
||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||||
except OSError:
|
except OSError:
|
||||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert isinstance(model_key, str)
|
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
|
||||||
|
|
||||||
stanza["base"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
stanza["type"] = ModelType(model_type)
|
stanza["type"] = ModelType(model_type)
|
||||||
stanza["name"] = model_name
|
stanza["name"] = model_name
|
||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
stanza["current_hash"] = hash
|
stanza["current_hash"] = hash
|
||||||
|
new_key = hash # deterministic key assignment
|
||||||
|
|
||||||
|
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||||
|
if stanza["type"] == ModelType.IPAdapter:
|
||||||
|
try:
|
||||||
|
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||||
|
self.config.models_path / stanza.path
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
|
|
||||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
|||||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
self._update_model(key, new_config)
|
self._update_model(key, new_config)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||||
self._add_model(new_key, new_config)
|
self._add_model(new_key, new_config)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
self.logger.warning(f"Model {model_name} is already in the database")
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
|
|||||||
)
|
)
|
||||||
except sqlite3.IntegrityError as exc:
|
except sqlite3.IntegrityError as exc:
|
||||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||||
|
|
||||||
|
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||||
|
with open(model_path / "image_encoder.txt") as f:
|
||||||
|
encoder = f.read()
|
||||||
|
return encoder.strip()
|
||||||
|
400
invokeai/backend/model_manager/config_new.py
Normal file
400
invokeai/backend/model_manager/config_new.py
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
Configuration definitions for image generation models.
|
||||||
|
|
||||||
|
Typical usage:
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager import ModelConfigFactory
|
||||||
|
raw = dict(path='models/sd-1/main/foo.ckpt',
|
||||||
|
name='foo',
|
||||||
|
base='sd-1',
|
||||||
|
type='main',
|
||||||
|
config='configs/stable-diffusion/v1-inference.yaml',
|
||||||
|
variant='normal',
|
||||||
|
format='checkpoint'
|
||||||
|
)
|
||||||
|
config = ModelConfigFactory.make_config(raw)
|
||||||
|
print(config.name)
|
||||||
|
|
||||||
|
Validation errors will raise an InvalidModelConfigException error.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from pydantic import BaseModel, Discriminator, Field, JsonValue, Tag, TypeAdapter
|
||||||
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
||||||
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
|
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||||
|
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelConfigException(Exception):
|
||||||
|
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSourceType(str, Enum):
|
||||||
|
"""The source of the model."""
|
||||||
|
|
||||||
|
HF_REPO_ID = "hf_repo_id"
|
||||||
|
CIVITAI = "civitai"
|
||||||
|
URL = "url"
|
||||||
|
PATH = "path"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelType(str, Enum):
|
||||||
|
"""Base model type."""
|
||||||
|
|
||||||
|
Any = "any"
|
||||||
|
StableDiffusion1 = "sd-1"
|
||||||
|
StableDiffusion2 = "sd-2"
|
||||||
|
StableDiffusionXL = "sdxl"
|
||||||
|
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||||
|
# Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
"""Model type."""
|
||||||
|
|
||||||
|
ONNX = "onnx"
|
||||||
|
Main = "main"
|
||||||
|
Vae = "vae"
|
||||||
|
Lora = "lora"
|
||||||
|
ControlNet = "controlnet" # used by model_probe
|
||||||
|
TextualInversion = "embedding"
|
||||||
|
IPAdapter = "ip_adapter"
|
||||||
|
CLIPVision = "clip_vision"
|
||||||
|
T2IAdapter = "t2i_adapter"
|
||||||
|
|
||||||
|
|
||||||
|
class SubModelType(str, Enum):
|
||||||
|
"""Submodel type."""
|
||||||
|
|
||||||
|
UNet = "unet"
|
||||||
|
TextEncoder = "text_encoder"
|
||||||
|
TextEncoder2 = "text_encoder_2"
|
||||||
|
Tokenizer = "tokenizer"
|
||||||
|
Tokenizer2 = "tokenizer_2"
|
||||||
|
Vae = "vae"
|
||||||
|
VaeDecoder = "vae_decoder"
|
||||||
|
VaeEncoder = "vae_encoder"
|
||||||
|
Scheduler = "scheduler"
|
||||||
|
SafetyChecker = "safety_checker"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelVariantType(str, Enum):
|
||||||
|
"""Variant type."""
|
||||||
|
|
||||||
|
Normal = "normal"
|
||||||
|
Inpaint = "inpaint"
|
||||||
|
Depth = "depth"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFormat(str, Enum):
|
||||||
|
"""Storage format of model."""
|
||||||
|
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Lycoris = "lycoris"
|
||||||
|
Onnx = "onnx"
|
||||||
|
Olive = "olive"
|
||||||
|
EmbeddingFile = "embedding_file"
|
||||||
|
EmbeddingFolder = "embedding_folder"
|
||||||
|
InvokeAI = "invokeai"
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerPredictionType(str, Enum):
|
||||||
|
"""Scheduler prediction type."""
|
||||||
|
|
||||||
|
Epsilon = "epsilon"
|
||||||
|
VPrediction = "v_prediction"
|
||||||
|
Sample = "sample"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRepoVariant(str, Enum):
|
||||||
|
"""Various hugging face variants on the diffusers format."""
|
||||||
|
|
||||||
|
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||||
|
FP16 = "fp16"
|
||||||
|
FP32 = "fp32"
|
||||||
|
ONNX = "onnx"
|
||||||
|
OPENVINO = "openvino"
|
||||||
|
FLAX = "flax"
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelConfigBase(BaseModel):
|
||||||
|
"""The configuration of a model."""
|
||||||
|
|
||||||
|
id: str = Field(description="The unique identifier of the model") # Primary Key
|
||||||
|
hash: str = Field(description="The BLAKE3 hash of the model.", frozen=True)
|
||||||
|
base: BaseModelType = Field(description="The base of the model")
|
||||||
|
path: str = Field(description="The path of the model")
|
||||||
|
name: str = Field(description="The name of the model")
|
||||||
|
description: Optional[str] = Field(description="The description of the model", default=None)
|
||||||
|
|
||||||
|
def compute_hash(self, algorithm: ALGORITHM = "blake3") -> str:
|
||||||
|
"""Compute the hash of the model."""
|
||||||
|
return ModelHash(algorithm).hash(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
class _CheckpointConfig(_ModelConfigBase):
|
||||||
|
"""Model config for checkpoint-style models."""
|
||||||
|
|
||||||
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
config_path: str = Field(description="Path to the checkpoint model config file")
|
||||||
|
|
||||||
|
|
||||||
|
class _DiffusersConfig(_ModelConfigBase):
|
||||||
|
"""Model config for diffusers-style models."""
|
||||||
|
|
||||||
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALycorisConfig(_ModelConfigBase):
|
||||||
|
"""Model config for LoRA/Lycoris models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||||
|
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRADiffusersConfig(_ModelConfigBase):
|
||||||
|
"""Model config for LoRA/Diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||||
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
class VaeCheckpointConfig(_ModelConfigBase):
|
||||||
|
"""Model config for standalone VAE models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
class VaeDiffusersConfig(_ModelConfigBase):
|
||||||
|
"""Model config for standalone VAE models (diffusers version)."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||||
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||||
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionFileConfig(_ModelConfigBase):
|
||||||
|
"""Model config for textual inversion embeddings."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||||
|
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionFolderConfig(_ModelConfigBase):
|
||||||
|
"""Model config for textual inversion embeddings."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||||
|
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
|
||||||
|
|
||||||
|
|
||||||
|
class _MainConfig(_ModelConfigBase):
|
||||||
|
"""Model config for main models."""
|
||||||
|
|
||||||
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
|
upcast_attention: bool = False
|
||||||
|
ztsnr_training: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||||
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||||
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterConfig(_ModelConfigBase):
|
||||||
|
"""Model config for IP Adaptor format models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||||
|
image_encoder_model_id: str
|
||||||
|
format: Literal[ModelFormat.InvokeAI]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionDiffusersConfig(_ModelConfigBase):
|
||||||
|
"""Model config for ClipVision."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||||
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterConfig(_ModelConfigBase):
|
||||||
|
"""Model config for T2I."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||||
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_discriminator_value(v: Any) -> str:
|
||||||
|
"""
|
||||||
|
Computes the discriminator value for a model config.
|
||||||
|
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||||
|
"""
|
||||||
|
if isinstance(v, dict):
|
||||||
|
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
|
||||||
|
return f"{v.getattr('type')}.{v.getattr('format')}"
|
||||||
|
|
||||||
|
|
||||||
|
AnyModelConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||||
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||||
|
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
|
||||||
|
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
|
||||||
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||||
|
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||||
|
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
|
||||||
|
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||||
|
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||||
|
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||||
|
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||||
|
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||||
|
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||||
|
],
|
||||||
|
Discriminator(get_model_discriminator_value),
|
||||||
|
]
|
||||||
|
|
||||||
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRecord(BaseModel):
|
||||||
|
"""A record of a model in the database."""
|
||||||
|
|
||||||
|
# Internal DB/record data
|
||||||
|
id: str = Field(description="The unique identifier of the model") # Primary Key
|
||||||
|
config: AnyModelConfig = Field(description="The configuration of the model")
|
||||||
|
source: str = Field(
|
||||||
|
description="The original source of the model (path, URL or repo_id)",
|
||||||
|
frozen=True, # This field is immutable
|
||||||
|
)
|
||||||
|
source_type: ModelSourceType = Field(
|
||||||
|
description="The type of the source of the model",
|
||||||
|
frozen=True, # This field is immutable
|
||||||
|
)
|
||||||
|
source_api_response: Optional[JsonValue] = Field(
|
||||||
|
description="The original API response from which the model was installed.",
|
||||||
|
default=None,
|
||||||
|
frozen=True, # This field is immutable
|
||||||
|
)
|
||||||
|
created_at: datetime | str = Field(description="When the model was created")
|
||||||
|
updated_at: datetime | str = Field(description="When the model was last updated")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfigFactory(object):
|
||||||
|
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_config(
|
||||||
|
cls,
|
||||||
|
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||||
|
key: Optional[str] = None,
|
||||||
|
dest_class: Optional[Type[_ModelConfigBase]] = None,
|
||||||
|
timestamp: Optional[float] = None,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Return the appropriate config object from raw dict values.
|
||||||
|
|
||||||
|
:param model_data: A raw dict corresponding the obect fields to be
|
||||||
|
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||||
|
object, which will be passed through unchanged.
|
||||||
|
:param dest_class: The config class to be returned. If not provided, will
|
||||||
|
be selected automatically.
|
||||||
|
"""
|
||||||
|
model: Optional[_ModelConfigBase] = None
|
||||||
|
if isinstance(model_data, _ModelConfigBase):
|
||||||
|
model = model_data
|
||||||
|
elif dest_class:
|
||||||
|
model = dest_class.model_validate(model_data)
|
||||||
|
else:
|
||||||
|
# mypy doesn't typecheck TypeAdapters well?
|
||||||
|
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||||
|
assert model is not None
|
||||||
|
if key:
|
||||||
|
model.key = key
|
||||||
|
if timestamp:
|
||||||
|
model.last_modified = timestamp
|
||||||
|
return model # type: ignore
|
@ -7,60 +7,138 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
|||||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||||
'a8e693a126ea5b831c96064dc569956f'
|
'a8e693a126ea5b831c96064dc569956f'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union
|
from typing import Callable, Literal, Union
|
||||||
|
|
||||||
from imohash import hashfile
|
from blake3 import blake3
|
||||||
|
|
||||||
|
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||||
|
|
||||||
|
ALGORITHM = Literal[
|
||||||
|
"md5",
|
||||||
|
"sha1",
|
||||||
|
"sha1_fast",
|
||||||
|
"sha224",
|
||||||
|
"sha256",
|
||||||
|
"sha384",
|
||||||
|
"sha512",
|
||||||
|
"blake2b",
|
||||||
|
"blake2s",
|
||||||
|
"sha3_224",
|
||||||
|
"sha3_256",
|
||||||
|
"sha3_384",
|
||||||
|
"sha3_512",
|
||||||
|
"shake_128",
|
||||||
|
"shake_256",
|
||||||
|
"blake3",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FastModelHash(object):
|
class ModelHash:
|
||||||
"""FastModelHash obect provides one public class method, hash()."""
|
"""
|
||||||
|
Creates a hash of a model using a specified algorithm.
|
||||||
|
|
||||||
@classmethod
|
:param algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
|
||||||
|
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||||
|
|
||||||
|
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||||
|
|
||||||
|
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||||
|
|
||||||
|
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||||
|
that directory hashes are never weaker than the file hashes.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
|
||||||
|
```py
|
||||||
|
ModelHash().hash("path/to/some/model.safetensors")
|
||||||
|
ModelHash("md5").hash("path/to/model/dir/")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, algorithm: ALGORITHM = "blake3") -> None:
|
||||||
|
if algorithm == "blake3":
|
||||||
|
self._hash_file = self._blake3
|
||||||
|
elif algorithm == "sha1_fast":
|
||||||
|
self._hash_file = self._sha1_fast
|
||||||
|
elif algorithm in hashlib.algorithms_available:
|
||||||
|
self._hash_file = self._get_hashlib(algorithm)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
|
def hash(self, model_location: Union[str, Path]) -> str:
|
||||||
"""
|
"""
|
||||||
Return hexdigest string for model located at model_location.
|
Return hexdigest string for model located at model_location.
|
||||||
|
|
||||||
|
If model_location is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||||
|
directory. The final composite hash is always computed using BLAKE3.
|
||||||
|
|
||||||
:param model_location: Path to the model
|
:param model_location: Path to the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_location = Path(model_location)
|
model_location = Path(model_location)
|
||||||
if model_location.is_file():
|
if model_location.is_file():
|
||||||
return cls._hash_file(model_location)
|
return self._hash_file(model_location)
|
||||||
elif model_location.is_dir():
|
elif model_location.is_dir():
|
||||||
return cls._hash_dir(model_location)
|
return self._hash_dir(model_location)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||||
|
|
||||||
@classmethod
|
def _hash_dir(self, model_location: Path) -> str:
|
||||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
"""Compute the hash for all files in a directory and return a hexdigest."""
|
||||||
"""
|
model_component_paths = self._get_file_paths(model_location)
|
||||||
Fasthash a single file and return its hexdigest.
|
|
||||||
|
|
||||||
:param model_location: Path to the model file
|
component_hashes: list[str] = []
|
||||||
"""
|
for component in sorted(model_component_paths):
|
||||||
# we return md5 hash of the filehash to make it shorter
|
component_hashes.append(self._hash_file(component))
|
||||||
# cryptographic security not needed here
|
|
||||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
|
||||||
|
# for the composite hash
|
||||||
|
composite_hasher = blake3()
|
||||||
|
for h in component_hashes:
|
||||||
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
|
return composite_hasher.hexdigest()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
def _get_file_paths(cls, dir: Path) -> list[Path]:
|
||||||
components: Dict[str, str] = {}
|
"""Return a list of all model files in the directory."""
|
||||||
|
files: list[Path] = []
|
||||||
|
for root, _dirs, _files in os.walk(dir):
|
||||||
|
for file in _files:
|
||||||
|
if file.endswith(MODEL_FILE_EXTENSIONS):
|
||||||
|
files.append(Path(root, file))
|
||||||
|
return files
|
||||||
|
|
||||||
for root, _dirs, files in os.walk(model_location):
|
@staticmethod
|
||||||
for file in files:
|
def _blake3(file_path: Path) -> str:
|
||||||
# only tally tensor files because diffusers config files change slightly
|
"""Hashes a file using BLAKE3"""
|
||||||
# depending on how the model was downloaded/converted.
|
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
file_hasher.update_mmap(file_path)
|
||||||
continue
|
return file_hasher.hexdigest()
|
||||||
path = (Path(root) / file).as_posix()
|
|
||||||
fast_hash = cls._hash_file(path)
|
|
||||||
components.update({path: fast_hash})
|
|
||||||
|
|
||||||
# hash all the model hashes together, using alphabetic file order
|
@staticmethod
|
||||||
md5 = hashlib.md5()
|
def _sha1_fast(file_path: Path) -> str:
|
||||||
for _path, fast_hash in sorted(components.items()):
|
"""Hashes a file using SHA1, but with a block size of 2**16.
|
||||||
md5.update(fast_hash.encode("utf-8"))
|
The result is not a correct SHA1 hash for the file, due to the padding introduced by the block size.
|
||||||
return md5.hexdigest()
|
The algorithm is, however, very fast."""
|
||||||
|
BLOCK_SIZE = 2**16
|
||||||
|
file_hash = hashlib.sha1()
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
data = f.read(BLOCK_SIZE)
|
||||||
|
file_hash.update(data)
|
||||||
|
return file_hash.hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||||
|
"""Hashes a file using a hashlib algorithm"""
|
||||||
|
|
||||||
|
def hasher(file_path: Path) -> str:
|
||||||
|
file_hasher = hashlib.new(algorithm)
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
file_hasher.update(f.read())
|
||||||
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
|
return hasher
|
||||||
|
@ -160,7 +160,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
|||||||
nsfw=model_json["nsfw"],
|
nsfw=model_json["nsfw"],
|
||||||
restrictions=LicenseRestrictions(
|
restrictions=LicenseRestrictions(
|
||||||
AllowNoCredit=model_json["allowNoCredit"],
|
AllowNoCredit=model_json["allowNoCredit"],
|
||||||
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
|
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||||
AllowDerivatives=model_json["allowDerivatives"],
|
AllowDerivatives=model_json["allowDerivatives"],
|
||||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||||
),
|
),
|
||||||
|
@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
|
|||||||
AllowDifferentLicense: bool = Field(
|
AllowDifferentLicense: bool = Field(
|
||||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||||
)
|
)
|
||||||
AllowCommercialUse: Optional[CommercialUsage] = Field(
|
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
|
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -142,7 +142,10 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
|||||||
if self.restrictions.AllowCommercialUse is None:
|
if self.restrictions.AllowCommercialUse is None:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
# accommodate schema change
|
||||||
|
acu = self.restrictions.AllowCommercialUse
|
||||||
|
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||||
|
return CommercialUsage.No not in commercial_usage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allow_derivatives(self) -> bool:
|
def allow_derivatives(self) -> bool:
|
||||||
|
@ -21,7 +21,7 @@ from .config import (
|
|||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from .hash import FastModelHash
|
from .hash import ModelHash
|
||||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||||
|
|
||||||
CkptType = Dict[str, Any]
|
CkptType = Dict[str, Any]
|
||||||
@ -147,7 +147,7 @@ class ModelProbe(object):
|
|||||||
if not probe_class:
|
if not probe_class:
|
||||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||||
|
|
||||||
hash = FastModelHash.hash(model_path)
|
hash = ModelHash().hash(model_path)
|
||||||
probe = probe_class(model_path)
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
fields["path"] = model_path.as_posix()
|
fields["path"] = model_path.as_posix()
|
||||||
|
@ -64,6 +64,7 @@ dependencies = [
|
|||||||
|
|
||||||
# Auxiliary dependencies, pinned only if necessary.
|
# Auxiliary dependencies, pinned only if necessary.
|
||||||
"albumentations",
|
"albumentations",
|
||||||
|
"blake3",
|
||||||
"click",
|
"click",
|
||||||
"datasets",
|
"datasets",
|
||||||
"Deprecated",
|
"Deprecated",
|
||||||
@ -72,7 +73,6 @@ dependencies = [
|
|||||||
"easing-functions",
|
"easing-functions",
|
||||||
"einops",
|
"einops",
|
||||||
"facexlib",
|
"facexlib",
|
||||||
"imohash",
|
|
||||||
"matplotlib", # needed for plotting of Penner easing functions
|
"matplotlib", # needed for plotting of Penner easing functions
|
||||||
"npyscreen",
|
"npyscreen",
|
||||||
"omegaconf",
|
"omegaconf",
|
||||||
|
@ -31,7 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
|
|||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
key = mm2_installer.register_path(embedding_file)
|
key = mm2_installer.register_path(embedding_file)
|
||||||
assert key is not None
|
assert key is not None
|
||||||
assert len(key) == 32
|
assert len(key) == 40 # length of the sha1 hash
|
||||||
|
|
||||||
|
|
||||||
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||||
|
File diff suppressed because one or more lines are too long
@ -133,7 +133,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
|||||||
assert metadata.id == 215485
|
assert metadata.id == 215485
|
||||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||||
assert metadata.version_id == 242807
|
assert metadata.version_id == 242807
|
||||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||||
|
|
||||||
|
80
tests/test_model_hash.py
Normal file
80
tests/test_model_hash.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# pyright:reportPrivateUsage=false
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from blake3 import blake3
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.hash import ALGORITHM, ModelHash
|
||||||
|
|
||||||
|
test_cases: list[tuple[ALGORITHM, str]] = [
|
||||||
|
("md5", "a0cd925fc063f98dbf029eee315060c3"),
|
||||||
|
("sha1", "9e362940e5603fdc60566ea100a288ba2fe48b8c"),
|
||||||
|
("sha256", "6dbdb6a147ad4d808455652bf5a10120161678395f6bfbd21eb6fe4e731aceeb"),
|
||||||
|
(
|
||||||
|
"sha512",
|
||||||
|
"c4a10476b21e00042f638ad5755c561d91f2bb599d3504d25409495e1c7eda94543332a1a90fbb4efdaf9ee462c33e0336b5eae4acfb1fa0b186af452dd67dc6",
|
||||||
|
),
|
||||||
|
("blake3", "ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("algorithm,expected_hash", test_cases)
|
||||||
|
def test_model_hash_hashes_file(tmp_path: Path, algorithm: ALGORITHM, expected_hash: str):
|
||||||
|
file = Path(tmp_path / "test")
|
||||||
|
file.write_text("model data")
|
||||||
|
md5 = ModelHash(algorithm).hash(file)
|
||||||
|
assert md5 == expected_hash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("algorithm", ["md5", "sha1", "sha256", "sha512", "blake3"])
|
||||||
|
def test_model_hash_hashes_dir(tmp_path: Path, algorithm: ALGORITHM):
|
||||||
|
model_hash = ModelHash(algorithm)
|
||||||
|
files = [Path(tmp_path, f"{i}.bin") for i in range(5)]
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
f.write_text("data")
|
||||||
|
|
||||||
|
md5 = model_hash.hash(tmp_path)
|
||||||
|
|
||||||
|
# Manual implementation of composite hash - always uses BLAKE3
|
||||||
|
composite_hasher = blake3()
|
||||||
|
for f in files:
|
||||||
|
h = model_hash.hash(f)
|
||||||
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
|
|
||||||
|
assert md5 == composite_hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_hash_raises_error_on_invalid_algorithm():
|
||||||
|
with pytest.raises(ValueError, match="Algorithm invalid_algorithm not available"):
|
||||||
|
ModelHash("invalid_algorithm") # pyright: ignore [reportArgumentType]
|
||||||
|
|
||||||
|
|
||||||
|
def paths_to_str_set(paths: Iterable[Path]) -> set[str]:
|
||||||
|
return {str(p) for p in paths}
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_hash_filters_out_non_model_files(tmp_path: Path):
|
||||||
|
model_files = {
|
||||||
|
Path(tmp_path, f"{i}.{ext}") for i, ext in enumerate([".ckpt", ".safetensors", ".bin", ".pt", ".pth"])
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, f in enumerate(model_files):
|
||||||
|
f.write_text(f"data{i}")
|
||||||
|
|
||||||
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||||
|
|
||||||
|
# Add file that should be ignored - hash should not change
|
||||||
|
file = tmp_path / "test.icecream"
|
||||||
|
file.write_text("data")
|
||||||
|
|
||||||
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
||||||
|
|
||||||
|
# Add file that should not be ignored - hash should change
|
||||||
|
file = tmp_path / "test.bin"
|
||||||
|
file.write_text("more data")
|
||||||
|
model_files.add(file)
|
||||||
|
|
||||||
|
assert paths_to_str_set(ModelHash._get_file_paths(tmp_path)) == paths_to_str_set(model_files)
|
Reference in New Issue
Block a user