make install_path and register_path work; refactor model probing
This commit is contained in:
parent
8c7a7bc897
commit
80bc9be3ab
@@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints

from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
@@ -336,10 +336,8 @@ class InvokeAIAppConfig(InvokeAISettings):
    )

    @classmethod
    def get_config(cls, **kwargs) -> InvokeAIAppConfig:
        """
        This returns a singleton InvokeAIAppConfig configuration object.
        """
    def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
        """Return a singleton InvokeAIAppConfig configuration object."""
        if (
            cls.singleton_config is None
            or type(cls.singleton_config) is not cls
@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import traceback

from typing import Any, Optional
@@ -343,7 +342,7 @@ class EventServiceBase:
        )

    def emit_model_install_error(self,
                                 source:str,
                                 source: str,
                                 error_type: str,
                                 error: str,
                                 ) -> None:
@@ -1 +1,2 @@
from .model_install_base import ModelInstallServiceBase  # noqa F401
from .model_install_default import ModelInstallService  # noqa F401
@@ -1,15 +1,15 @@
import traceback

from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union, List
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl

from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase


class InstallStatus(str, Enum):
@@ -27,8 +27,8 @@ class ModelInstallJob(BaseModel):
    source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
    status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
    local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
    error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
    error: Optional[str] = Field(default=None, description="Error traceback")  # noqa #501
    error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
    error: str = Field(default="", description="Error traceback")  # noqa #501

    def set_error(self, e: Exception) -> None:
        """Record the error and traceback from an exception."""
@@ -43,8 +43,8 @@ class ModelInstallServiceBase(ABC):
    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: ModelRecordServiceBase,
        app_config: InvokeAIAppConfig,
        record_store: ModelRecordServiceBase,
        event_bus: Optional["EventServiceBase"] = None,
    ):
        """
@@ -54,15 +54,22 @@ class ModelInstallServiceBase(ABC):
        :param store: Systemwide ModelConfigStore
        :param event_bus: InvokeAI event bus for reporting events to.
        """
        pass

    @property
    @abstractmethod
    def app_config(self) -> InvokeAIAppConfig:
        """Return the appConfig object associated with the installer."""

    @property
    @abstractmethod
    def record_store(self) -> ModelRecordServiceBase:
        """Return the ModelRecordService object associated with the installer."""

    @abstractmethod
    def register_path(
        self,
        model_path: Union[Path, str],
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe and register the model at model_path.
@@ -70,21 +77,32 @@ class ModelInstallServiceBase(ABC):
        This keeps the model in its current location.

        :param model_path: Filesystem Path to the model.
        :param name: Name for the model (optional)
        :param description: Description for the model (optional)
        :param metadata: Dict of attributes that will override probed values.
        :param metadata: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """
        pass

    @abstractmethod
    def unregister(self, key: str) -> None:
        """Remove model with indicated key from the database."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove model with indicated key from the database and delete weight files from disk."""

    @abstractmethod
    def conditionally_delete(self, key: str) -> None:
        """
        Remove model with indicated key from the database
        and conditionally delete weight files from disk
        if they reside within InvokeAI's models directory.
        """

    @abstractmethod
    def install_path(
        self,
        model_path: Union[Path, str],
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    )-> str:
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Probe, register and install the model in the models directory.
@@ -92,20 +110,15 @@ class ModelInstallServiceBase(ABC):
        the models directory handled by InvokeAI.

        :param model_path: Filesystem Path to the model.
        :param name: Name for the model (optional)
        :param description: Description for the model (optional)
        :param metadata: Dict of attributes that will override probed values.
        :param metadata: Dict of attributes that will override autoassigned values.
        :returns id: The string ID of the registered model.
        """
        pass
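The split between register_path() and install_path() is the heart of this refactor. A minimal usage sketch (editor's illustration, not part of the diff; `installer` is an assumed concrete ModelInstallService and the path is hypothetical):

    # Hypothetical sketch of the two entry points above.
    from pathlib import Path

    model = Path("/downloads/my_model.safetensors")  # hypothetical local model

    # register_path(): the file stays where it is; only a database record is added.
    key_in_place = installer.register_path(model, metadata={"name": "my-model"})

    # install_path(): the file is first moved under InvokeAI's managed models
    # directory, then registered.
    key_installed = installer.install_path(model, metadata={"description": "demo"})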
    @abstractmethod
    def install_model(
    def import_model(
        self,
        source: Union[str, Path, AnyHttpUrl],
        inplace: bool = True,
        name: Optional[str] = None,
        description: Optional[str] = None,
        variant: Optional[str] = None,
        subfolder: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
@@ -125,9 +138,9 @@ class ModelInstallServiceBase(ABC):
        specify a subfolder of the HF repository to download from.

        :param metadata: Optional dict. Any fields in this dict
        will override corresponding probe fields. Use it to override
        `base_type`, `model_type`, `format`, `prediction_type`, `image_size`,
        and `ztsnr_training`.
        will override corresponding autoassigned probe fields. Use it to override
        `name`, `description`, `base_type`, `model_type`, `format`,
        `prediction_type`, `image_size`, and/or `ztsnr_training`.

        :param access_token: Access token for use in downloading remote
        models.
@@ -154,10 +167,9 @@ class ModelInstallServiceBase(ABC):
        4. None (usually returns fp32 model)

        """
        pass

    @abstractmethod
    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
        """
        Wait for all pending installs to complete.
@@ -169,7 +181,6 @@ class ModelInstallServiceBase(ABC):
        It will return a dict that maps the source model
        path, URL or repo_id to the ID of the installed model.
        """
        pass

    @abstractmethod
    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
@@ -180,20 +191,7 @@ class ModelInstallServiceBase(ABC):
        :param install: Install if True, otherwise register in place.
        :returns list of IDs: Returns list of IDs of models registered/installed
        """
        pass

    @abstractmethod
    def sync_to_config(self):
    def sync_to_config(self) -> None:
        """Synchronize models on disk to those in memory."""
        pass

    @abstractmethod
    def hash(self, model_path: Union[Path, str]) -> str:
        """
        Compute and return the fast hash of the model.

        :param model_path: Path to the model on disk.
        :return str: FastHash of the model for use as an ID.
        """
        pass
@@ -1,22 +1,28 @@
"""Model installation class."""

import threading

from hashlib import sha256
from pathlib import Path
from typing import Dict, Optional, Union, List
from queue import Queue
from random import randbytes
from shutil import move, rmtree
from typing import Any, Dict, List, Optional, Union

from pydantic.networks import AnyHttpUrl

from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase

from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    DuplicateModelException,
    InvalidModelConfigException,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.util.logging import InvokeAILogger

from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase

# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop")
@@ -25,113 +31,195 @@ STOP_JOB = ModelInstallJob(source="stop")
class ModelInstallService(ModelInstallServiceBase):
    """class for InvokeAI model installation."""

    config: InvokeAIAppConfig
    store: ModelRecordServiceBase
    _app_config: InvokeAIAppConfig
    _record_store: ModelRecordServiceBase
    _event_bus: Optional[EventServiceBase] = None
    _install_queue: Queue
    _install_queue: Queue[ModelInstallJob]
    _install_jobs: Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]
    _logger: InvokeAILogger

    def __init__(self,
                 config: InvokeAIAppConfig,
                 store: ModelRecordServiceBase,
                 install_queue: Optional[Queue] = None,
                 app_config: InvokeAIAppConfig,
                 record_store: ModelRecordServiceBase,
                 event_bus: Optional[EventServiceBase] = None
                 ):
        self.config = config
        self.store = store
        self._install_queue = install_queue or Queue()
        """
        Initialize the installer object.

        :param app_config: InvokeAIAppConfig object
        :param record_store: Previously-opened ModelRecordService database
        :param event_bus: Optional EventService object
        """
        self._app_config = app_config
        self._record_store = record_store
        self._install_queue = Queue()
        self._event_bus = event_bus
        self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
        self._start_installer_thread()

    def _start_installer_thread(self):
    @property
    def app_config(self) -> InvokeAIAppConfig:  # noqa D102
        return self._app_config

    @property
    def record_store(self) -> ModelRecordServiceBase:  # noqa D102
        return self._record_store

    def _start_installer_thread(self) -> None:
        threading.Thread(target=self._install_next_item, daemon=True).start()

    def _install_next_item(self):
    def _install_next_item(self) -> None:
        done = False
        while not done:
            job = self._install_queue.get()
            if job == STOP_JOB:
                done = True
            elif job.status == InstallStatus.WAITING:
                assert job.local_path is not None
                try:
                    self._signal_job_running(job)
                    self.register_path(job.path)
                    self.register_path(job.local_path)
                    self._signal_job_completed(job)
                except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
                    self._signal_job_errored(job, e)
                except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
                    self._signal_job_errored(job, excp)

    def _signal_job_running(self, job: ModelInstallJob):
    def _signal_job_running(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.RUNNING
        if self._event_bus:
            self._event_bus.emit_model_install_started(job.source)
            self._event_bus.emit_model_install_started(str(job.source))

    def _signal_job_completed(self, job: ModelInstallJob):
    def _signal_job_completed(self, job: ModelInstallJob) -> None:
        job.status = InstallStatus.COMPLETED
        if self._event_bus:
            self._event_bus.emit_model_install_completed(job.source, job.dest)
            assert job.local_path is not None
            self._event_bus.emit_model_install_completed(str(job.source), job.local_path.as_posix())

    def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
        job.set_error(e)
    def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
        job.set_error(excp)
        if self._event_bus:
            self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
            self._event_bus.emit_model_install_error(str(job.source), job.error_type, job.error)
    def register_path(
        self,
        model_path: Union[Path, str],
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        info: ModelProbeInfo = self._probe_model(model_path, metadata)
        return self._register(model_path, info)
        metadata = metadata or {}
        if metadata.get('source') is None:
            metadata['source'] = model_path.as_posix()
        return self._register(model_path, metadata)

    def install_path(
        self,
        model_path: Union[Path, str],
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> str:
        raise NotImplementedError
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:  # noqa D102
        model_path = Path(model_path)
        metadata = metadata or {}
        if metadata.get('source') is None:
            metadata['source'] = model_path.as_posix()

    def install_model(
        info: AnyModelConfig = self._probe_model(Path(model_path), metadata)

        old_hash = info.original_hash
        dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
        new_path = self._move_model(model_path, dest_path)
        new_hash = FastModelHash.hash(new_path)
        assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."

        return self._register(
            new_path,
            metadata,
            info,
        )
    def import_model(
        self,
        source: Union[str, Path, AnyHttpUrl],
        inplace: bool = True,
        name: Optional[str] = None,
        description: Optional[str] = None,
        variant: Optional[str] = None,
        subfolder: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        access_token: Optional[str] = None,
    ) -> ModelInstallJob:
    ) -> ModelInstallJob:  # noqa D102
        raise NotImplementedError

    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:  # noqa D102
        self._install_queue.join()
        return self._install_jobs

    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
        raise NotImplementedError

    def sync_to_config(self):
    def sync_to_config(self) -> None:  # noqa D102
        raise NotImplementedError

    def hash(self, model_path: Union[Path, str]) -> str:
        return FastModelHash.hash(model_path)
    def unregister(self, key: str) -> None:  # noqa D102
        self.record_store.del_model(key)

    # The following are internal methods
    def _create_name(self, model_path: Union[Path, str]) -> str:
        model_path = Path(model_path)
        if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
            return model_path.stem
    def delete(self, key: str) -> None:  # noqa D102
        model = self.record_store.get_model(key)
        path = self.app_config.models_path / model.path
        if path.is_dir():
            rmtree(path)
        else:
            return model_path.name
            path.unlink()
        self.unregister(key)

    def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
        info = info or ModelProbe.probe(Path(model_path))
        name: str = self._create_name(model_path)
        return f"a {info.model_type} model {name} based on {info.base_type}"
    def conditionally_delete(self, key: str) -> None:  # noqa D102
        """Unregister the model. Delete its files only if they are within our models directory."""
        model = self.record_store.get_model(key)
        models_dir = self.app_config.models_path
        model_path = models_dir / model.path
        if model_path.is_relative_to(models_dir):
            self.delete(key)
        else:
            self.unregister(key)

    def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
        name: str = name or self._create_name(model_path)
        raise NotImplementedError
    def _move_model(self, old_path: Path, new_path: Path) -> Path:
        if old_path == new_path:
            return old_path

        new_path.parent.mkdir(parents=True, exist_ok=True)

        # if path already exists then we jigger the name to make it unique
        counter: int = 1
        while new_path.exists():
            path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            if not path.exists():
                new_path = path
            counter += 1
        move(old_path, new_path)
        return new_path
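To make the collision handling in _move_model concrete, here is a standalone demonstration (editor's sketch with hypothetical file names, not part of the diff): an occupied destination gains a zero-padded counter suffix until a free name is found.

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        dest = Path(tmp) / "model.safetensors"
        dest.touch()                                   # simulate a prior install
        (Path(tmp) / "model_01.safetensors").touch()   # ...and a second one

        new_path = dest
        counter = 1
        while new_path.exists():
            candidate = new_path.with_stem(new_path.stem + f"_{counter:02d}")
            if not candidate.exists():
                new_path = candidate
            counter += 1
        print(new_path.name)  # -> model_02.safetensors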
    def _probe_model(self, model_path: Path, metadata: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
        info: AnyModelConfig = ModelProbe.probe(Path(model_path))
        if metadata:  # used to override probe fields
            for key, value in metadata.items():
                setattr(info, key, value)
        return info

    def _create_key(self) -> str:
        return sha256(randbytes(100)).hexdigest()[0:32]

    def _register(self,
                  model_path: Path,
                  metadata: Optional[Dict[str, Any]] = None,
                  info: Optional[AnyModelConfig] = None) -> str:

        info = info or ModelProbe.probe(model_path, metadata)
        key = self._create_key()

        model_path = model_path.absolute()
        if model_path.is_relative_to(self.app_config.models_path):
            model_path = model_path.relative_to(self.app_config.models_path)

        info.path = model_path.as_posix()

        # add 'main' specific fields
        if hasattr(info, 'config'):
            # make config relative to our root
            info.config = self.app_config.legacy_conf_dir / info.config
        self.record_store.add_model(key, info)
        return key
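An illustration of the path normalization in _register() above (editor's sketch, hypothetical paths): models inside the managed directory are stored relative to it, which is what the install test at the bottom of this page asserts.

    from pathlib import Path

    models_path = Path("/invokeai/models")
    model_path = Path("/invokeai/models/sd-1/embedding/foo.safetensors").absolute()
    if model_path.is_relative_to(models_path):
        model_path = model_path.relative_to(models_path)
    print(model_path.as_posix())  # -> sd-1/embedding/foo.safetensors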
@@ -6,3 +6,11 @@ from .model_records_base import (  # noqa F401
    UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL  # noqa F401

__all__ = [
    'ModelRecordServiceBase',
    'ModelRecordServiceSQL',
    'DuplicateModelException',
    'InvalidModelException',
    'UnknownModelException',
]
@@ -23,7 +23,7 @@ from enum import Enum
from typing import Literal, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
from typing_extensions import Annotated, Any, Dict


class InvalidModelConfigException(Exception):
@@ -125,7 +125,7 @@ class ModelConfigBase(BaseModel):
        validate_assignment=True,
    )

    def update(self, attributes: dict):
    def update(self, attributes: Dict[str, Any]) -> None:
        """Update the object with fields in dict."""
        for key, value in attributes.items():
            setattr(self, key, value)  # may raise a validation error
@@ -198,8 +198,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main
    # Note that we do not need prediction_type or upcast_attention here
    # because they are provided in the checkpoint's own config file.


class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
invokeai/backend/model_manager/probe.py (new file, 645 lines)
@@ -0,0 +1,645 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from invokeai.backend.util.util import SilenceWarnings

from .config import (
    AnyModelConfig,
    BaseModelType,
    InvalidModelConfigException,
    ModelConfigFactory,
    ModelFormat,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
)
from .hash import FastModelHash

CkptType = Dict[str, Any]

LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: "v1-inference.yaml",
        ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v2-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusionXL: {
        ModelVariantType.Normal: "sd_xl_base.yaml",
    },
    BaseModelType.StableDiffusionXLRefiner: {
        ModelVariantType.Normal: "sd_xl_refiner.yaml",
    },
}
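The SD-2.x entries carry a second tier keyed by scheduler prediction type. A sketch of the lookup that _get_config_path() performs later in this file (editor's illustration, assuming the enums imported above):

    # Resolving a legacy config for an SD-2.x v-prediction checkpoint.
    base = BaseModelType.StableDiffusion2
    variant = ModelVariantType.Normal
    prediction = SchedulerPredictionType.VPrediction

    config_file = LEGACY_CONFIGS[base][variant]
    if isinstance(config_file, dict):  # SD-2.x needs the prediction-type tier
        config_file = config_file[prediction]
    print(config_file)  # -> v2-inference-v.yaml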

class ProbeBase(object):
    """Base class for probes."""

    def __init__(self, model_path: Path):
        self.model_path = model_path

    def get_base_type(self) -> BaseModelType:
        """Get model base type."""
        raise NotImplementedError

    def get_format(self) -> ModelFormat:
        """Get model file format."""
        raise NotImplementedError

    def get_variant_type(self) -> Optional[ModelVariantType]:
        """Get model variant type."""
        return None

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Get model scheduler prediction type."""
        return None


class ModelProbe(object):
    PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "LatentConsistencyModelPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
    ) -> None:
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
    ) -> AnyModelConfig:
        return cls.probe(model_path, fields)

    @classmethod
    def probe(
        cls,
        model_path: Path,
        fields: Optional[Dict[str, Any]] = None,
    ) -> AnyModelConfig:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. Any entries in
        the optional fields dict override the corresponding probed values.
        """
        if fields is None:
            fields = {}

        format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
        model_info = None
        model_type = None
        if format_type == "diffusers":
            model_type = cls.get_model_type_from_folder(model_path)
        else:
            model_type = cls.get_model_type_from_checkpoint(model_path)
        print(f'DEBUG: model_type={model_type}')
        format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type

        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")

        hash = FastModelHash.hash(model_path)
        probe = probe_class(model_path)

        fields['path'] = model_path.as_posix()
        fields['type'] = fields.get('type') or model_type
        fields['base'] = fields.get('base') or probe.get_base_type()
        fields['variant'] = fields.get('variant') or probe.get_variant_type()
        fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
        fields['name'] = fields.get('name') or cls.get_model_name(model_path)
        fields['description'] = (
            fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
        )
        fields['format'] = fields.get('format') or probe.get_format()
        fields['original_hash'] = fields.get('original_hash') or hash
        fields['current_hash'] = fields.get('current_hash') or hash

        # additional work for main models
        if fields['type'] == ModelType.Main:
            if fields['format'] == ModelFormat.Checkpoint:
                fields['config'] = cls._get_config_path(
                    model_path, fields['base'], fields['variant'], fields['prediction_type']
                ).as_posix()
            elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
                fields['upcast_attention'] = fields.get('upcast_attention') or (
                    fields['base'] == BaseModelType.StableDiffusion2
                    and fields['prediction_type'] == SchedulerPredictionType.VPrediction
                )

        model_info = ModelConfigFactory.make_config(fields)
        return model_info

    @classmethod
    def get_model_name(cls, model_path: Path) -> str:
        if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
            return model_path.stem
        else:
            return model_path.name

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
        """Get the model type of a hugging-face style folder."""
        class_name = None
        error_hint = None
        for suffix in ["bin", "safetensors"]:
            if (folder_path / f"learned_embeds.{suffix}").exists():
                return ModelType.TextualInversion
            if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
                return ModelType.Lora
        if (folder_path / "unet/model.onnx").exists():
            return ModelType.ONNX
        if (folder_path / "image_encoder.txt").exists():
            return ModelType.IPAdapter

        i = folder_path / "model_index.json"
        c = folder_path / "config.json"
        config_path = i if i.exists() else c if c.exists() else None

        if config_path:
            with open(config_path, "r") as file:
                conf = json.load(file)
            if "_class_name" in conf:
                class_name = conf["_class_name"]
            elif "architectures" in conf:
                class_name = conf["architectures"][0]
            else:
                class_name = None
        else:
            error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type
        else:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelConfigException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _get_config_path(cls,
                         model_path: Path,
                         base_type: BaseModelType,
                         variant: ModelVariantType,
                         prediction_type: SchedulerPredictionType) -> Path:
        # look for a YAML file adjacent to the model file first
        possible_conf = model_path.with_suffix(".yaml")
        if possible_conf.exists():
            return possible_conf.absolute()
        config_file = LEGACY_CONFIGS[base_type][variant]
        if isinstance(config_file, dict):  # need another tier for sd-2.x models
            config_file = config_file[prediction_type]
        return Path(config_file)

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model_path.name, model_path)
                model = torch.load(model_path)
                assert isinstance(model, dict)
                return model
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
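A sketch of the dispatch that get_model_type_from_folder() performs for a diffusers-style folder (editor's illustration with a hypothetical folder layout, using the CLASS2TYPE table defined above):

    # Identify a diffusers folder by the class name in its model_index.json.
    import json
    from pathlib import Path

    folder = Path("/models/some-diffusers-model")  # hypothetical
    conf = json.loads((folder / "model_index.json").read_text())
    model_type = ModelProbe.CLASS2TYPE.get(conf["_class_name"])
    print(model_type)  # e.g. ModelType.Main for "StableDiffusionPipeline"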

# ##################################################3
# Checkpoint probing
# ##################################################3


class CheckpointProbeBase(ProbeBase):
    def __init__(self, model_path: Path):
        super().__init__(model_path)
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

    def get_format(self) -> ModelFormat:
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelConfigException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
            )


class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelConfigException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        """Return model prediction type."""
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion2:
            checkpoint = self.checkpoint
            state_dict = self.checkpoint.get("state_dict") or checkpoint
            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
            if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
                if "global_step" in checkpoint:
                    if checkpoint["global_step"] == 220000:
                        return SchedulerPredictionType.Epsilon
                    elif checkpoint["global_step"] == 110000:
                        return SchedulerPredictionType.VPrediction
            return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts

        elif type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts
        else:
            return SchedulerPredictionType.Epsilon

class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    """Class for LoRA checkpoints."""

    def get_format(self) -> ModelFormat:
        return ModelFormat("lycoris")

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Class for probing embeddings."""

    def get_format(self) -> ModelFormat:
        return ModelFormat.EmbeddingFile

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            raise InvalidModelConfigException("Could not determine base type")

class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Class for probing controlnets."""

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
        raise InvalidModelConfigException(f"Unable to determine base type for {self.model_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()

########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> ModelFormat:
        return ModelFormat("diffusers")


class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        with open(self.model_path / "unet" / "config.json", "r") as file:
            unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
            scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            raise InvalidModelConfigException(
                f"Unknown scheduler prediction type: {scheduler_conf['prediction_type']}"
            )

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            config_file = self.model_path / "unet" / "config.json"
            with open(config_file, "r") as file:
                conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal

class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # SD and SDXL VAEs are the same shape (3-channel RGB to 4-channel float scaled down
            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.model_path.name
        if name == "vae":
            name = self.model_path.parent.name
        return name

class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> ModelFormat:
        return ModelFormat.EmbeddingFolder

    def get_base_type(self) -> BaseModelType:
        path = self.model_path / "learned_embeds.bin"
        if not path.exists():
            raise InvalidModelConfigException(
                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
            )
        return TextualInversionCheckpointProbe(path).get_base_type()


class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> ModelFormat:
        return ModelFormat("onnx")

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else (
                BaseModelType.StableDiffusion2
                if dimension == 1024
                else BaseModelType.StableDiffusionXL
                if dimension == 2048
                else None
            )
        )
        if not base_model:
            raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
        return base_model


class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelConfigException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file).get_base_type()

class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> IPAdapterModelFormat:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.model_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelConfigException("Unknown IP-Adapter model format.")

        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelConfigException(
                f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
            )


class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any

class T2IAdapterFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.model_path / "config.json"
        if not config_file.exists():
            raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        elif adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelConfigException(
                f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
            )

############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
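With the probes registered, the whole refactored pipeline is driven through ModelProbe.probe(). A usage sketch (editor's illustration with a hypothetical path, not part of the diff):

    # Probe a local checkpoint, overriding the autoassigned name, just as
    # ModelInstallService._probe_model() does with its metadata dict.
    from pathlib import Path

    info = ModelProbe.probe(
        Path("/downloads/checkpoint.safetensors"),  # hypothetical path
        fields={"name": "my-model"},
    )
    print(info.type, info.base, info.name)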
@@ -219,6 +219,8 @@ exclude = [
# global mypy config
[tool.mypy]
ignore_missing_imports = true  # ignores missing types in third-party libraries
strict = true
exclude = ["tests/*"]

# overrides for specific modules
[[tool.mypy.overrides]]
@@ -3,7 +3,7 @@
import argparse
from pathlib import Path

from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_manager.probe import ModelProbe

parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
tests/app/services/model_install/test_model_install.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""
Test the model installer
"""

import pytest
from pathlib import Path
from pydantic import ValidationError
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.model_manager.config import ModelType, BaseModelType
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceBase
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase


@pytest.fixture
def test_file(datadir: Path) -> Path:
    return datadir / "test_embedding.safetensors"


@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
    return InvokeAIAppConfig(
        root=datadir / "root",
        models_dir=datadir / "root/models",
    )


@pytest.fixture
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
    store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
    return store


@pytest.fixture
def installer(app_config: InvokeAIAppConfig,
              store: ModelRecordServiceBase) -> ModelInstallServiceBase:
    return ModelInstallService(app_config=app_config,
                               record_store=store
                               )


def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
    store = installer.record_store
    matches = store.search_by_attr(model_name="test_embedding")
    assert len(matches) == 0
    key = installer.register_path(test_file)
    assert key is not None
    assert len(key) == 32


def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
    store = installer.record_store
    key = installer.register_path(test_file)
    model_record = store.get_model(key)
    assert model_record is not None
    assert model_record.name == "test_embedding"
    assert model_record.type == ModelType.TextualInversion
    assert Path(model_record.path) == test_file
    assert model_record.base == BaseModelType('sd-1')
    assert model_record.description is not None
    assert model_record.source is not None
    assert Path(model_record.source) == test_file


def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
    key = None
    with pytest.raises(ValidationError):
        key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
    assert key is None


def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
    store = installer.record_store
    key = installer.register_path(test_file,
                                  {
                                      "name": "banana_sushi",
                                      "source": "fake/repo_id",
                                      "current_hash": "New Hash"
                                  }
                                  )
    model_record = store.get_model(key)
    assert model_record.name == "banana_sushi"
    assert model_record.source == "fake/repo_id"
    assert model_record.current_hash == "New Hash"


def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
    store = installer.record_store
    key = installer.install_path(test_file)
    model_record = store.get_model(key)
    assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
    assert model_record.source == test_file.as_posix()
@@ -0,0 +1,79 @@
model:
  base_learning_rate: 1.0e-04
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -0,0 +1 @@
Dummy file to establish git path.
Binary file not shown.