make install_path and register_path work; refactor model probing

Lincoln Stein 2023-11-23 23:15:32 -05:00
parent 8c7a7bc897
commit 80bc9be3ab
14 changed files with 1027 additions and 121 deletions

View File

@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
@ -336,10 +336,8 @@ class InvokeAIAppConfig(InvokeAISettings):
)
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""
This returns a singleton InvokeAIAppConfig configuration object.
"""
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
or type(cls.singleton_config) is not cls

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import traceback
from typing import Any, Optional
@ -343,7 +342,7 @@ class EventServiceBase:
)
def emit_model_install_error(self,
source:str,
source: str,
error_type: str,
error: str,
) -> None:

View File

@ -1 +1,2 @@
from .model_install_base import ModelInstallServiceBase # noqa F401
from .model_install_default import ModelInstallService # noqa F401

View File

@ -1,15 +1,15 @@
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union, List
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
class InstallStatus(str, Enum):
@ -27,8 +27,8 @@ class ModelInstallJob(BaseModel):
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
error: str = Field(default="", description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
@ -43,8 +43,8 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
store: ModelRecordServiceBase,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
@ -54,15 +54,22 @@ class ModelInstallServiceBase(ABC):
:param record_store: Systemwide ModelRecordServiceBase store
:param event_bus: InvokeAI event bus for reporting events to.
"""
pass
@property
@abstractmethod
def app_config(self) -> InvokeAIAppConfig:
"""Return the appConfig object associated with the installer."""
@property
@abstractmethod
def record_store(self) -> ModelRecordServiceBase:
"""Return the ModelRecoreService object associated with the installer."""
@abstractmethod
def register_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe and register the model at model_path.
@ -70,21 +77,32 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param name: Name for the model (optional)
:param description: Description for the model (optional)
:param metadata: Dict of attributes that will override probed values.
:param metadata: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def unregister(self, key: str) -> None:
"""Remove model with indicated key from the database."""
@abstractmethod
def delete(self, key: str) -> None:
"""Remove model with indicated key from the database and delete weight files from disk."""
@abstractmethod
def conditionally_delete(self, key: str) -> None:
"""
Remove model with indicated key from the database
and conditionally delete weight files from disk
if they reside within InvokeAI's models directory.
"""
@abstractmethod
def install_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
)-> str:
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
@ -92,20 +110,15 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param name: Name for the model (optional)
:param description: Description for the model (optional)
:param metadata: Dict of attributes that will override probed values.
:param metadata: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_model(
def import_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
name: Optional[str] = None,
description: Optional[str] = None,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
@ -125,9 +138,9 @@ class ModelInstallServiceBase(ABC):
specify a subfolder of the HF repository to download from.
:param metadata: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type`, `image_size`,
and `ztsnr_training`.
will override corresponding autoassigned probe fields. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
:param access_token: Access token for use in downloading remote
models.
@ -154,10 +167,9 @@ class ModelInstallServiceBase(ABC):
4. None (usually returns fp32 model)
"""
pass
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
"""
Wait for all pending installs to complete.
@ -169,7 +181,6 @@ class ModelInstallServiceBase(ABC):
It will return a dict that maps each source model
path, URL or repo_id to its ModelInstallJob.
"""
pass
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
@ -180,20 +191,7 @@ class ModelInstallServiceBase(ABC):
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
pass
@abstractmethod
def sync_to_config(self):
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in memory."""
pass
@abstractmethod
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
:param model_path: Path to the model on disk.
:return str: FastHash of the model for use as an ID.
"""
pass
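To make the intended division of labor concrete, here is a hedged usage sketch against the abstract interface above. The paths and the concrete installer are illustrative only: register_path records a model where it already lives, while install_path moves it under the InvokeAI models directory before registering it.
# Hedged sketch, not part of this commit; "installer" is any concrete
# ModelInstallServiceBase and the paths are placeholders.
from pathlib import Path
from invokeai.app.services.model_install import ModelInstallServiceBase
def registration_examples(installer: ModelInstallServiceBase) -> None:
    # Register in place: the model file stays where it is on disk.
    in_place_key = installer.register_path(Path("/data/checkpoints/my_model.safetensors"))
    # Install: the model is moved into the models directory first, with
    # metadata overriding the autoassigned name and description.
    installed_key = installer.install_path(
        Path("/downloads/my_vae"),
        metadata={"name": "my-vae", "description": "a locally downloaded VAE"},
    )
    print(in_place_key, installed_key)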

View File

@ -1,22 +1,28 @@
"""Model installation class."""
import threading
from hashlib import sha256
from pathlib import Path
from typing import Dict, Optional, Union, List
from queue import Queue
from random import randbytes
from shutil import move, rmtree
from typing import Any, Dict, List, Optional, Union
from pydantic.networks import AnyHttpUrl
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
DuplicateModelException,
InvalidModelConfigException,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.util.logging import InvokeAILogger
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop")
@ -25,113 +31,195 @@ STOP_JOB = ModelInstallJob(source="stop")
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
config: InvokeAIAppConfig
store: ModelRecordServiceBase
_app_config: InvokeAIAppConfig
_record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue
_install_queue: Queue[ModelInstallJob]
_install_jobs: Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]
_logger: InvokeAILogger
def __init__(self,
config: InvokeAIAppConfig,
store: ModelRecordServiceBase,
install_queue: Optional[Queue] = None,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional[EventServiceBase] = None
):
self.config = config
self.store = store
self._install_queue = install_queue or Queue()
"""
Initialize the installer object.
:param app_config: InvokeAIAppConfig object
:param record_store: Previously-opened ModelRecordService database
:param event_bus: Optional EventService object
"""
self._app_config = app_config
self._record_store = record_store
self._install_queue = Queue()
self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._start_installer_thread()
def _start_installer_thread(self):
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
return self._app_config
@property
def record_store(self) -> ModelRecordServiceBase: # noqa D102
return self._record_store
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
def _install_next_item(self):
def _install_next_item(self) -> None:
done = False
while not done:
job = self._install_queue.get()
if job == STOP_JOB:
done = True
elif job.status == InstallStatus.WAITING:
assert job.local_path is not None
try:
self._signal_job_running(job)
self.register_path(job.path)
self.register_path(job.local_path)
self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
self._signal_job_errored(job, e)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp)
def _signal_job_running(self, job: ModelInstallJob):
def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING
if self._event_bus:
self._event_bus.emit_model_install_started(job.source)
self._event_bus.emit_model_install_started(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob):
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
if self._event_bus:
self._event_bus.emit_model_install_completed(job.source, job.dest)
assert job.local_path is not None
self._event_bus.emit_model_install_completed(str(job.source), job.local_path.as_posix())
def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
job.set_error(e)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp)
if self._event_bus:
self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
self._event_bus.emit_model_install_error(str(job.source), job.error_type, job.error)
def register_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
metadata: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, metadata)
return self._register(model_path, info)
metadata = metadata or {}
if metadata.get('source') is None:
metadata['source'] = model_path.as_posix()
return self._register(model_path, metadata)
def install_path(
self,
model_path: Union[Path, str],
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> str:
raise NotImplementedError
metadata: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
metadata = metadata or {}
if metadata.get('source') is None:
metadata['source'] = model_path.as_posix()
def install_model(
info: AnyModelConfig = self._probe_model(Path(model_path), metadata)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
new_path = self._move_model(model_path, dest_path)
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
metadata,
info,
)
def import_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
name: Optional[str] = None,
description: Optional[str] = None,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
) -> ModelInstallJob: # noqa D102
raise NotImplementedError
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]: # noqa D102
self._install_queue.join()
return self._install_jobs
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
raise NotImplementedError
def sync_to_config(self):
def sync_to_config(self) -> None: # noqa D102
raise NotImplementedError
def hash(self, model_path: Union[Path, str]) -> str:
return FastModelHash.hash(model_path)
def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key)
# The following are internal methods
def _create_name(self, model_path: Union[Path, str]) -> str:
model_path = Path(model_path)
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
return model_path.stem
def delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
return model_path.name
path.unlink()
self.unregister(key)
def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
info = info or ModelProbe.probe(Path(model_path))
name: str = self._create_name(model_path)
return f"a {info.model_type} model {name} based on {info.base_type}"
def conditionally_delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.delete(key)
else:
self.unregister(key)
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
name: str = name or self._create_name(model_path)
raise NotImplementedError
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, metadata: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if metadata: # used to override probe fields
for key, value in metadata.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(self,
model_path: Path,
metadata: Optional[Dict[str, Any]] = None,
info: Optional[AnyModelConfig] = None) -> str:
info = info or ModelProbe.probe(model_path, metadata)
key = self._create_key()
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, 'config'):
# make config relative to our root
info.config = self.app_config.legacy_conf_dir / info.config
self.record_store.add_model(key, info)
return key
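For orientation, a hedged end-to-end sketch of wiring up the default service implemented above and exercising the deletion variants; the root path is illustrative, and the wiring mirrors the pytest fixtures further below.
# Hedged sketch, not part of this commit; paths are placeholders.
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig(root=Path("/opt/invokeai"), models_dir=Path("/opt/invokeai/models"))
db = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
installer = ModelInstallService(app_config=app_config, record_store=ModelRecordServiceSQL(db))
# install_path moves the weights under models_path, so conditionally_delete
# removes both the record and the files; a model registered in place with
# register_path would only have its record removed.
key = installer.install_path(Path("/downloads/test_embedding.safetensors"))
installer.conditionally_delete(key)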

View File

@ -6,3 +6,11 @@ from .model_records_base import ( # noqa F401
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401
__all__ = [
'ModelRecordServiceBase',
'ModelRecordServiceSQL',
'DuplicateModelException',
'InvalidModelException',
'UnknownModelException',
]

View File

@ -23,7 +23,7 @@ from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
from typing_extensions import Annotated, Any, Dict
class InvalidModelConfigException(Exception):
@ -125,7 +125,7 @@ class ModelConfigBase(BaseModel):
validate_assignment=True,
)
def update(self, attributes: dict):
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
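Because update() assigns through pydantic's validate_assignment, an override that conflicts with a field's declared type raises a ValidationError instead of silently corrupting the record. A hedged sketch with an illustrative path, mirroring the override test near the end of this commit:
# Hedged sketch; relies only on APIs introduced elsewhere in this commit.
from pathlib import Path
from pydantic import ValidationError
from invokeai.backend.model_manager.config import ModelType
from invokeai.backend.model_manager.probe import ModelProbe
info = ModelProbe.probe(Path("/downloads/test_embedding.safetensors"))
info.update({"name": "banana_sushi"})  # accepted: name is an ordinary string field
try:
    info.update({"type": ModelType("lora")})  # rejected: type is a Literal on this config class
except ValidationError as err:
    print(err)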
@ -198,8 +198,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):

View File

@ -0,0 +1,645 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from invokeai.backend.util.util import SilenceWarnings
from .config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
CkptType = Dict[str, Any]
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
class ProbeBase(object):
"""Base class for probes."""
def __init__(self, model_path: Path):
self.model_path = model_path
def get_base_type(self) -> BaseModelType:
"""Get model base type."""
raise NotImplementedError
def get_format(self) -> ModelFormat:
"""Get model file format."""
raise NotImplementedError
def get_variant_type(self) -> Optional[ModelVariantType]:
"""Get model variant type."""
return None
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Get model scheduler prediction type."""
return None
class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
}
CLASS2TYPE = {
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
) -> None:
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod
def probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. Entries in the
optional fields dict override the corresponding probed values; use it to
override attributes such as name, description, base, type and format.
"""
if fields is None:
fields = {}
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type == "diffusers":
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path)
fields['path'] = model_path.as_posix()
fields['type'] = fields.get('type') or model_type
fields['base'] = fields.get('base') or probe.get_base_type()
fields['variant'] = fields.get('variant') or probe.get_variant_type()
fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
fields['name'] = fields.get('name') or cls.get_model_name(model_path)
fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields['format'] = fields.get('format') or probe.get_format()
fields['original_hash'] = fields.get('original_hash') or hash
fields['current_hash'] = fields.get('current_hash') or hash
# additional work for main models
if fields['type'] == ModelType.Main:
if fields['format'] == ModelFormat.Checkpoint:
fields['config'] = cls._get_config_path(model_path, fields['base'], fields['variant'], fields['prediction_type']).as_posix()
elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
fields['upcast_attention'] = fields.get('upcast_attention') or (
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
)
model_info = ModelConfigFactory.make_config(fields)
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
"""Get the model type of a hugging-face style folder."""
class_name = None
error_hint = None
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelConfigException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _get_config_path(cls,
model_path: Path,
base_type: BaseModelType,
variant: ModelVariantType,
prediction_type: SchedulerPredictionType) -> Path:
# look for a YAML file adjacent to the model file first
possible_conf = model_path.with_suffix(".yaml")
if possible_conf.exists():
return possible_conf.absolute()
config_file = LEGACY_CONFIGS[base_type][variant]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
return Path(config_file)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
assert isinstance(model, dict)
return model
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class CheckpointProbeBase(ProbeBase):
def __init__(self, model_path: Path):
super().__init__(model_path)
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelConfigException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
)
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelConfigException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
"""Return model prediction type."""
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return SchedulerPredictionType.Epsilon
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
"""Class for probing embeddings."""
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFile
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
else:
raise InvalidModelConfigException("Could not determine base type")
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Class for probing controlnets."""
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise InvalidModelConfigException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self) -> ModelFormat:
return ModelFormat("diffusers")
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
config_file = self.model_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# SD and SDXL VAEs have the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), so we can't necessarily tell them apart by config hyperparameters;
# fall back to the model name.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":
name = self.model_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
path = self.model_path / "learned_embeds.bin"
if not path.exists():
raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file")
return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat("onnx")
def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelConfigException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> IPAdapterModelFormat:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.model_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelConfigException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type in ("full_adapter", "light_adapter"):
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelConfigException(
f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -219,6 +219,8 @@ exclude = [
# global mypy config
[tool.mypy]
ignore_missing_imports = true # ignores missing types in third-party libraries
strict = true
exclude = ["tests/*"]
# overrides for specific modules
[[tool.mypy.overrides]]

View File

@ -3,7 +3,7 @@
import argparse
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_manager.probe import ModelProbe
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(

View File

@ -0,0 +1,89 @@
"""
Test the model installer
"""
import pytest
from pathlib import Path
from pydantic import ValidationError
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager.config import ModelType, BaseModelType
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceBase
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
@pytest.fixture
def test_file(datadir: Path) -> Path:
return datadir / "test_embedding.safetensors"
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
return store
@pytest.fixture
def installer(app_config: InvokeAIAppConfig,
store: ModelRecordServiceBase) -> ModelInstallServiceBase:
return ModelInstallService(app_config=app_config,
record_store=store
)
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = installer.register_path(test_file)
assert key is not None
assert len(key) == 32
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(test_file)
model_record = store.get_model(key)
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == test_file
assert model_record.base == BaseModelType('sd-1')
assert model_record.description is not None
assert model_record.source is not None
assert Path(model_record.source) == test_file
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
key = installer.register_path(test_file, metadata={"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(test_file,
metadata={
"name": "banana_sushi",
"source": "fake/repo_id",
"current_hash": "New Hash"
}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
assert model_record.source == "fake/repo_id"
assert model_record.current_hash == "New Hash"
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
store = installer.record_store
key = installer.install_path(test_file)
model_record = store.get_model(key)
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix()

View File

@ -0,0 +1,79 @@
model:
base_learning_rate: 1.0e-04
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder

View File

@ -0,0 +1 @@
Dummy file to establish git path.