mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make install_path and register_path work; refactor model probing
This commit is contained in:
parent
8c7a7bc897
commit
80bc9be3ab
@ -173,7 +173,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
@ -336,10 +336,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
|
||||||
"""
|
"""Return a singleton InvokeAIAppConfig configuration object."""
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
|
||||||
"""
|
|
||||||
if (
|
if (
|
||||||
cls.singleton_config is None
|
cls.singleton_config is None
|
||||||
or type(cls.singleton_config) is not cls
|
or type(cls.singleton_config) is not cls
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -343,7 +342,7 @@ class EventServiceBase:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_install_error(self,
|
def emit_model_install_error(self,
|
||||||
source:str,
|
source: str,
|
||||||
error_type: str,
|
error_type: str,
|
||||||
error: str,
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1 +1,2 @@
|
|||||||
|
from .model_install_base import ModelInstallServiceBase # noqa F401
|
||||||
from .model_install_default import ModelInstallService # noqa F401
|
from .model_install_default import ModelInstallService # noqa F401
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union, List
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
|
|
||||||
|
|
||||||
class InstallStatus(str, Enum):
|
class InstallStatus(str, Enum):
|
||||||
@ -27,8 +27,8 @@ class ModelInstallJob(BaseModel):
|
|||||||
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
|
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
|
||||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||||
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
|
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
|
||||||
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
|
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
|
||||||
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
|
error: str = Field(default="", description="Error traceback") # noqa #501
|
||||||
|
|
||||||
def set_error(self, e: Exception) -> None:
|
def set_error(self, e: Exception) -> None:
|
||||||
"""Record the error and traceback from an exception."""
|
"""Record the error and traceback from an exception."""
|
||||||
@ -43,8 +43,8 @@ class ModelInstallServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
store: ModelRecordServiceBase,
|
record_store: ModelRecordServiceBase,
|
||||||
event_bus: Optional["EventServiceBase"] = None,
|
event_bus: Optional["EventServiceBase"] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -54,15 +54,22 @@ class ModelInstallServiceBase(ABC):
|
|||||||
:param store: Systemwide ModelConfigStore
|
:param store: Systemwide ModelConfigStore
|
||||||
:param event_bus: InvokeAI event bus for reporting events to.
|
:param event_bus: InvokeAI event bus for reporting events to.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def app_config(self) -> InvokeAIAppConfig:
|
||||||
|
"""Return the appConfig object associated with the installer."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def record_store(self) -> ModelRecordServiceBase:
|
||||||
|
"""Return the ModelRecoreService object associated with the installer."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
name: Optional[str] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
description: Optional[str] = None,
|
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Probe and register the model at model_path.
|
Probe and register the model at model_path.
|
||||||
@ -70,21 +77,32 @@ class ModelInstallServiceBase(ABC):
|
|||||||
This keeps the model in its current location.
|
This keeps the model in its current location.
|
||||||
|
|
||||||
:param model_path: Filesystem Path to the model.
|
:param model_path: Filesystem Path to the model.
|
||||||
:param name: Name for the model (optional)
|
:param metadata: Dict of attributes that will override autoassigned values.
|
||||||
:param description: Description for the model (optional)
|
|
||||||
:param metadata: Dict of attributes that will override probed values.
|
|
||||||
:returns id: The string ID of the registered model.
|
:returns id: The string ID of the registered model.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
@abstractmethod
|
||||||
|
def unregister(self, key: str) -> None:
|
||||||
|
"""Remove model with indicated key from the database."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, key: str) -> None:
|
||||||
|
"""Remove model with indicated key from the database and delete weight files from disk."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def conditionally_delete(self, key: str) -> None:
|
||||||
|
"""
|
||||||
|
Remove model with indicated key from the database
|
||||||
|
and conditeionally delete weight files from disk
|
||||||
|
if they reside within InvokeAI's models directory.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def install_path(
|
def install_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
name: Optional[str] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
description: Optional[str] = None,
|
) -> str:
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
|
||||||
)-> str:
|
|
||||||
"""
|
"""
|
||||||
Probe, register and install the model in the models directory.
|
Probe, register and install the model in the models directory.
|
||||||
|
|
||||||
@ -92,20 +110,15 @@ class ModelInstallServiceBase(ABC):
|
|||||||
the models directory handled by InvokeAI.
|
the models directory handled by InvokeAI.
|
||||||
|
|
||||||
:param model_path: Filesystem Path to the model.
|
:param model_path: Filesystem Path to the model.
|
||||||
:param name: Name for the model (optional)
|
:param metadata: Dict of attributes that will override autoassigned values.
|
||||||
:param description: Description for the model (optional)
|
|
||||||
:param metadata: Dict of attributes that will override probed values.
|
|
||||||
:returns id: The string ID of the registered model.
|
:returns id: The string ID of the registered model.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def install_model(
|
def import_model(
|
||||||
self,
|
self,
|
||||||
source: Union[str, Path, AnyHttpUrl],
|
source: Union[str, Path, AnyHttpUrl],
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
@ -125,9 +138,9 @@ class ModelInstallServiceBase(ABC):
|
|||||||
specify a subfolder of the HF repository to download from.
|
specify a subfolder of the HF repository to download from.
|
||||||
|
|
||||||
:param metadata: Optional dict. Any fields in this dict
|
:param metadata: Optional dict. Any fields in this dict
|
||||||
will override corresponding probe fields. Use it to override
|
will override corresponding autoassigned probe fields. Use it to override
|
||||||
`base_type`, `model_type`, `format`, `prediction_type`, `image_size`,
|
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||||
and `ztsnr_training`.
|
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||||
|
|
||||||
:param access_token: Access token for use in downloading remote
|
:param access_token: Access token for use in downloading remote
|
||||||
models.
|
models.
|
||||||
@ -154,10 +167,9 @@ class ModelInstallServiceBase(ABC):
|
|||||||
4. None (usually returns fp32 model)
|
4. None (usually returns fp32 model)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
|
||||||
"""
|
"""
|
||||||
Wait for all pending installs to complete.
|
Wait for all pending installs to complete.
|
||||||
|
|
||||||
@ -169,7 +181,6 @@ class ModelInstallServiceBase(ABC):
|
|||||||
It will return a dict that maps the source model
|
It will return a dict that maps the source model
|
||||||
path, URL or repo_id to the ID of the installed model.
|
path, URL or repo_id to the ID of the installed model.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||||
@ -180,20 +191,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
:param install: Install if True, otherwise register in place.
|
:param install: Install if True, otherwise register in place.
|
||||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sync_to_config(self):
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in memory."""
|
"""Synchronize models on disk to those in memory."""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def hash(self, model_path: Union[Path, str]) -> str:
|
|
||||||
"""
|
|
||||||
Compute and return the fast hash of the model.
|
|
||||||
|
|
||||||
:param model_path: Path to the model on disk.
|
|
||||||
:return str: FastHash of the model for use as an ID.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
@ -1,22 +1,28 @@
|
|||||||
"""Model installation class."""
|
"""Model installation class."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union, List
|
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from random import randbytes
|
||||||
|
from shutil import move, rmtree
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
|
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
|
|
||||||
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
|
from invokeai.backend.model_manager.config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
DuplicateModelException,
|
||||||
|
InvalidModelConfigException,
|
||||||
|
)
|
||||||
from invokeai.backend.model_manager.hash import FastModelHash
|
from invokeai.backend.model_manager.hash import FastModelHash
|
||||||
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
|
||||||
|
|
||||||
# marker that the queue is done and that thread should exit
|
# marker that the queue is done and that thread should exit
|
||||||
STOP_JOB = ModelInstallJob(source="stop")
|
STOP_JOB = ModelInstallJob(source="stop")
|
||||||
@ -25,113 +31,195 @@ STOP_JOB = ModelInstallJob(source="stop")
|
|||||||
class ModelInstallService(ModelInstallServiceBase):
|
class ModelInstallService(ModelInstallServiceBase):
|
||||||
"""class for InvokeAI model installation."""
|
"""class for InvokeAI model installation."""
|
||||||
|
|
||||||
config: InvokeAIAppConfig
|
_app_config: InvokeAIAppConfig
|
||||||
store: ModelRecordServiceBase
|
_record_store: ModelRecordServiceBase
|
||||||
_event_bus: Optional[EventServiceBase] = None
|
_event_bus: Optional[EventServiceBase] = None
|
||||||
_install_queue: Queue
|
_install_queue: Queue[ModelInstallJob]
|
||||||
|
_install_jobs: Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]
|
||||||
|
_logger: InvokeAILogger
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
store: ModelRecordServiceBase,
|
record_store: ModelRecordServiceBase,
|
||||||
install_queue: Optional[Queue] = None,
|
|
||||||
event_bus: Optional[EventServiceBase] = None
|
event_bus: Optional[EventServiceBase] = None
|
||||||
):
|
):
|
||||||
self.config = config
|
"""
|
||||||
self.store = store
|
Initialize the installer object.
|
||||||
self._install_queue = install_queue or Queue()
|
|
||||||
|
:param app_config: InvokeAIAppConfig object
|
||||||
|
:param record_store: Previously-opened ModelRecordService database
|
||||||
|
:param event_bus: Optional EventService object
|
||||||
|
"""
|
||||||
|
self._app_config = app_config
|
||||||
|
self._record_store = record_store
|
||||||
|
self._install_queue = Queue()
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
|
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||||
self._start_installer_thread()
|
self._start_installer_thread()
|
||||||
|
|
||||||
def _start_installer_thread(self):
|
@property
|
||||||
|
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||||
|
return self._app_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def record_store(self) -> ModelRecordServiceBase: # noqa D102
|
||||||
|
return self._record_store
|
||||||
|
|
||||||
|
def _start_installer_thread(self) -> None:
|
||||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||||
|
|
||||||
def _install_next_item(self):
|
def _install_next_item(self) -> None:
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
job = self._install_queue.get()
|
job = self._install_queue.get()
|
||||||
if job == STOP_JOB:
|
if job == STOP_JOB:
|
||||||
done = True
|
done = True
|
||||||
elif job.status == InstallStatus.WAITING:
|
elif job.status == InstallStatus.WAITING:
|
||||||
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
self.register_path(job.path)
|
self.register_path(job.local_path)
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
|
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
|
||||||
self._signal_job_errored(job, e)
|
self._signal_job_errored(job, excp)
|
||||||
|
|
||||||
def _signal_job_running(self, job: ModelInstallJob):
|
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||||
job.status = InstallStatus.RUNNING
|
job.status = InstallStatus.RUNNING
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_started(job.source)
|
self._event_bus.emit_model_install_started(str(job.source))
|
||||||
|
|
||||||
def _signal_job_completed(self, job: ModelInstallJob):
|
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||||
job.status = InstallStatus.COMPLETED
|
job.status = InstallStatus.COMPLETED
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_completed(job.source, job.dest)
|
assert job.local_path is not None
|
||||||
|
self._event_bus.emit_model_install_completed(str(job.source), job.local_path.as_posix())
|
||||||
|
|
||||||
def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
|
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||||
job.set_error(e)
|
job.set_error(excp)
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
|
self._event_bus.emit_model_install_error(str(job.source), job.error_type, job.error)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
name: Optional[str] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
description: Optional[str] = None,
|
) -> str: # noqa D102
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
|
||||||
) -> str:
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
info: ModelProbeInfo = self._probe_model(model_path, metadata)
|
metadata = metadata or {}
|
||||||
return self._register(model_path, info)
|
if metadata.get('source') is None:
|
||||||
|
metadata['source'] = model_path.as_posix()
|
||||||
|
return self._register(model_path, metadata)
|
||||||
|
|
||||||
def install_path(
|
def install_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
name: Optional[str] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
description: Optional[str] = None,
|
) -> str: # noqa D102
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
model_path = Path(model_path)
|
||||||
) -> str:
|
metadata = metadata or {}
|
||||||
raise NotImplementedError
|
if metadata.get('source') is None:
|
||||||
|
metadata['source'] = model_path.as_posix()
|
||||||
|
|
||||||
def install_model(
|
info: AnyModelConfig = self._probe_model(Path(model_path), metadata)
|
||||||
|
|
||||||
|
old_hash = info.original_hash
|
||||||
|
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||||
|
new_path = self._move_model(model_path, dest_path)
|
||||||
|
new_hash = FastModelHash.hash(new_path)
|
||||||
|
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||||
|
|
||||||
|
return self._register(
|
||||||
|
new_path,
|
||||||
|
metadata,
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def import_model(
|
||||||
self,
|
self,
|
||||||
source: Union[str, Path, AnyHttpUrl],
|
source: Union[str, Path, AnyHttpUrl],
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob: # noqa D102
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]: # noqa D102
|
||||||
self._install_queue.join()
|
self._install_queue.join()
|
||||||
|
return self._install_jobs
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def sync_to_config(self):
|
def sync_to_config(self) -> None: # noqa D102
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def hash(self, model_path: Union[Path, str]) -> str:
|
def unregister(self, key: str) -> None: # noqa D102
|
||||||
return FastModelHash.hash(model_path)
|
self.record_store.del_model(key)
|
||||||
|
|
||||||
# The following are internal methods
|
def delete(self, key: str) -> None: # noqa D102
|
||||||
def _create_name(self, model_path: Union[Path, str]) -> str:
|
model = self.record_store.get_model(key)
|
||||||
model_path = Path(model_path)
|
path = self.app_config.models_path / model.path
|
||||||
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
|
if path.is_dir():
|
||||||
return model_path.stem
|
rmtree(path)
|
||||||
else:
|
else:
|
||||||
return model_path.name
|
path.unlink()
|
||||||
|
self.unregister(key)
|
||||||
|
|
||||||
def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
def conditionally_delete(self, key: str) -> None: # noqa D102
|
||||||
info = info or ModelProbe.probe(Path(model_path))
|
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||||
name: str = self._create_name(model_path)
|
model = self.record_store.get_model(key)
|
||||||
return f"a {info.model_type} model {name} based on {info.base_type}"
|
models_dir = self.app_config.models_path
|
||||||
|
model_path = models_dir / model.path
|
||||||
|
if model_path.is_relative_to(models_dir):
|
||||||
|
self.delete(key)
|
||||||
|
else:
|
||||||
|
self.unregister(key)
|
||||||
|
|
||||||
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||||
name: str = name or self._create_name(model_path)
|
if old_path == new_path:
|
||||||
raise NotImplementedError
|
return old_path
|
||||||
|
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# if path already exists then we jigger the name to make it unique
|
||||||
|
counter: int = 1
|
||||||
|
while new_path.exists():
|
||||||
|
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||||
|
if not path.exists():
|
||||||
|
new_path = path
|
||||||
|
counter += 1
|
||||||
|
move(old_path, new_path)
|
||||||
|
return new_path
|
||||||
|
|
||||||
|
def _probe_model(self, model_path: Path, metadata: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||||
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||||
|
if metadata: # used to override probe fields
|
||||||
|
for key, value in metadata.items():
|
||||||
|
setattr(info, key, value)
|
||||||
|
return info
|
||||||
|
|
||||||
|
def _create_key(self) -> str:
|
||||||
|
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||||
|
|
||||||
|
def _register(self,
|
||||||
|
model_path: Path,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
info: Optional[AnyModelConfig] = None) -> str:
|
||||||
|
|
||||||
|
info = info or ModelProbe.probe(model_path, metadata)
|
||||||
|
key = self._create_key()
|
||||||
|
|
||||||
|
model_path = model_path.absolute()
|
||||||
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
|
model_path = model_path.relative_to(self.app_config.models_path)
|
||||||
|
|
||||||
|
info.path = model_path.as_posix()
|
||||||
|
|
||||||
|
# add 'main' specific fields
|
||||||
|
if hasattr(info, 'config'):
|
||||||
|
# make config relative to our root
|
||||||
|
info.config = self.app_config.legacy_conf_dir / info.config
|
||||||
|
self.record_store.add_model(key, info)
|
||||||
|
return key
|
||||||
|
@ -6,3 +6,11 @@ from .model_records_base import ( # noqa F401
|
|||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ModelRecordServiceBase',
|
||||||
|
'ModelRecordServiceSQL',
|
||||||
|
'DuplicateModelException',
|
||||||
|
'InvalidModelException',
|
||||||
|
'UnknownModelException',
|
||||||
|
]
|
||||||
|
@ -23,7 +23,7 @@ from enum import Enum
|
|||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
@ -125,7 +125,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, attributes: dict):
|
def update(self, attributes: Dict[str, Any]) -> None:
|
||||||
"""Update the object with fields in dict."""
|
"""Update the object with fields in dict."""
|
||||||
for key, value in attributes.items():
|
for key, value in attributes.items():
|
||||||
setattr(self, key, value) # may raise a validation error
|
setattr(self, key, value) # may raise a validation error
|
||||||
@ -198,8 +198,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
|||||||
"""Model config for main checkpoint models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
# Note that we do not need prediction_type or upcast_attention here
|
|
||||||
# because they are provided in the checkpoint's own config file.
|
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||||
|
645
invokeai/backend/model_manager/probe.py
Normal file
645
invokeai/backend/model_manager/probe.py
Normal file
@ -0,0 +1,645 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
||||||
|
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||||
|
from invokeai.backend.model_management.util import lora_token_vector_length
|
||||||
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
InvalidModelConfigException,
|
||||||
|
ModelConfigFactory,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
)
|
||||||
|
from .hash import FastModelHash
|
||||||
|
|
||||||
|
CkptType = Dict[str, Any]
|
||||||
|
|
||||||
|
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||||
|
BaseModelType.StableDiffusion1: {
|
||||||
|
ModelVariantType.Normal: "v1-inference.yaml",
|
||||||
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusion2: {
|
||||||
|
ModelVariantType.Normal: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||||
|
},
|
||||||
|
ModelVariantType.Inpaint: {
|
||||||
|
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||||
|
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusionXL: {
|
||||||
|
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||||
|
},
|
||||||
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
|
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProbeBase(object):
|
||||||
|
"""Base class for probes."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: Path):
|
||||||
|
self.model_path = model_path
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
"""Get model base type."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
"""Get model file format."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_variant_type(self) -> Optional[ModelVariantType]:
|
||||||
|
"""Get model variant type."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||||
|
"""Get model scheduler prediction type."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
class ModelProbe(object):
|
||||||
|
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||||
|
"diffusers": {},
|
||||||
|
"checkpoint": {},
|
||||||
|
"onnx": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
CLASS2TYPE = {
|
||||||
|
"StableDiffusionPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
|
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||||
|
"AutoencoderKL": ModelType.Vae,
|
||||||
|
"AutoencoderTiny": ModelType.Vae,
|
||||||
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
|
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||||
|
"T2IAdapter": ModelType.T2IAdapter,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_probe(
|
||||||
|
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||||
|
) -> None:
|
||||||
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def heuristic_probe(
|
||||||
|
cls,
|
||||||
|
model_path: Path,
|
||||||
|
fields: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
return cls.probe(model_path, fields)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def probe(
|
||||||
|
cls,
|
||||||
|
model_path: Path,
|
||||||
|
fields: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Probe the model at model_path and return sufficient information about it
|
||||||
|
to place it somewhere in the models directory hierarchy. If the model is
|
||||||
|
already loaded into memory, you may provide it as model in order to avoid
|
||||||
|
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||||
|
the path to the model and returns the SchedulerPredictionType.
|
||||||
|
"""
|
||||||
|
if fields is None:
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
|
model_info = None
|
||||||
|
model_type = None
|
||||||
|
if format_type == "diffusers":
|
||||||
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
|
else:
|
||||||
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
|
print(f'DEBUG: model_type={model_type}')
|
||||||
|
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||||
|
|
||||||
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
|
if not probe_class:
|
||||||
|
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||||
|
|
||||||
|
hash = FastModelHash.hash(model_path)
|
||||||
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
|
fields['path'] = model_path.as_posix()
|
||||||
|
fields['type'] = fields.get('type') or model_type
|
||||||
|
fields['base'] = fields.get('base') or probe.get_base_type()
|
||||||
|
fields['variant'] = fields.get('variant') or probe.get_variant_type()
|
||||||
|
fields['prediction_type'] = fields.get('prediction_type') or probe.get_scheduler_prediction_type()
|
||||||
|
fields['name'] = fields.get('name') or cls.get_model_name(model_path)
|
||||||
|
fields['description'] = fields.get('description') or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
|
fields['format'] = fields.get('format') or probe.get_format()
|
||||||
|
fields['original_hash'] = fields.get('original_hash') or hash
|
||||||
|
fields['current_hash'] = fields.get('current_hash') or hash
|
||||||
|
|
||||||
|
# additional work for main models
|
||||||
|
if fields['type'] == ModelType.Main:
|
||||||
|
if fields['format'] == ModelFormat.Checkpoint:
|
||||||
|
fields['config'] = cls._get_config_path(model_path, fields['base'], fields['variant'], fields['prediction_type']).as_posix()
|
||||||
|
elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
|
||||||
|
fields['upcast_attention'] = fields.get('upcast_attention') or (
|
||||||
|
fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = ModelConfigFactory.make_config(fields)
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_name(cls, model_path: Path) -> str:
|
||||||
|
if model_path.suffix in {'.safetensors', '.bin', '.pt', '.ckpt'}:
|
||||||
|
return model_path.stem
|
||||||
|
else:
|
||||||
|
return model_path.name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
|
||||||
|
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||||
|
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
|
||||||
|
|
||||||
|
if model_path.name == "learned_embeds.bin":
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||||
|
ckpt = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
|
for key in ckpt.keys():
|
||||||
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
|
return ModelType.Main
|
||||||
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
|
return ModelType.Vae
|
||||||
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
|
return ModelType.Lora
|
||||||
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
|
return ModelType.Lora
|
||||||
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
|
return ModelType.ControlNet
|
||||||
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
else:
|
||||||
|
# diffusers-ti
|
||||||
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
|
||||||
|
"""Get the model type of a hugging-face style folder."""
|
||||||
|
class_name = None
|
||||||
|
error_hint = None
|
||||||
|
for suffix in ["bin", "safetensors"]:
|
||||||
|
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||||
|
return ModelType.Lora
|
||||||
|
if (folder_path / "unet/model.onnx").exists():
|
||||||
|
return ModelType.ONNX
|
||||||
|
if (folder_path / "image_encoder.txt").exists():
|
||||||
|
return ModelType.IPAdapter
|
||||||
|
|
||||||
|
i = folder_path / "model_index.json"
|
||||||
|
c = folder_path / "config.json"
|
||||||
|
config_path = i if i.exists() else c if c.exists() else None
|
||||||
|
|
||||||
|
if config_path:
|
||||||
|
with open(config_path, "r") as file:
|
||||||
|
conf = json.load(file)
|
||||||
|
if "_class_name" in conf:
|
||||||
|
class_name = conf["_class_name"]
|
||||||
|
elif "architectures" in conf:
|
||||||
|
class_name = conf["architectures"][0]
|
||||||
|
else:
|
||||||
|
class_name = None
|
||||||
|
else:
|
||||||
|
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||||
|
|
||||||
|
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||||
|
return type
|
||||||
|
else:
|
||||||
|
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||||
|
|
||||||
|
# give up
|
||||||
|
raise InvalidModelConfigException(
|
||||||
|
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_config_path(cls,
|
||||||
|
model_path: Path,
|
||||||
|
base_type: BaseModelType,
|
||||||
|
variant: ModelVariantType,
|
||||||
|
prediction_type: SchedulerPredictionType) -> Path:
|
||||||
|
# look for a YAML file adjacent to the model file first
|
||||||
|
possible_conf = model_path.with_suffix(".yaml")
|
||||||
|
if possible_conf.exists():
|
||||||
|
return possible_conf.absolute()
|
||||||
|
config_file = LEGACY_CONFIGS[base_type][variant]
|
||||||
|
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||||
|
config_file = config_file[prediction_type]
|
||||||
|
return Path(config_file)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
||||||
|
with SilenceWarnings():
|
||||||
|
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||||
|
cls._scan_model(model_path.name, model_path)
|
||||||
|
model = torch.load(model_path)
|
||||||
|
assert isinstance(model, dict)
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
return safetensors.torch.load_file(model_path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
|
||||||
|
"""
|
||||||
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||||
|
and option to exit if an infected file is identified.
|
||||||
|
"""
|
||||||
|
# scan model
|
||||||
|
scan_result = scan_file_path(checkpoint)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||||
|
|
||||||
|
|
||||||
|
# ##################################################3
|
||||||
|
# Checkpoint probing
|
||||||
|
# ##################################################3
|
||||||
|
|
||||||
|
class CheckpointProbeBase(ProbeBase):
|
||||||
|
def __init__(self, model_path: Path):
|
||||||
|
super().__init__(model_path)
|
||||||
|
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat("checkpoint")
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||||
|
if model_type != ModelType.Main:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||||
|
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
if in_channels == 9:
|
||||||
|
return ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
return ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException(
|
||||||
|
f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
checkpoint = self.checkpoint
|
||||||
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||||
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException("Cannot determine base type")
|
||||||
|
|
||||||
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||||
|
"""Return model prediction type."""
|
||||||
|
type = self.get_base_type()
|
||||||
|
if type == BaseModelType.StableDiffusion2:
|
||||||
|
checkpoint = self.checkpoint
|
||||||
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||||
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
|
if "global_step" in checkpoint:
|
||||||
|
if checkpoint["global_step"] == 220000:
|
||||||
|
return SchedulerPredictionType.Epsilon
|
||||||
|
elif checkpoint["global_step"] == 110000:
|
||||||
|
return SchedulerPredictionType.VPrediction
|
||||||
|
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
||||||
|
|
||||||
|
elif type == BaseModelType.StableDiffusion1:
|
||||||
|
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
||||||
|
else:
|
||||||
|
return SchedulerPredictionType.Epsilon
|
||||||
|
|
||||||
|
|
||||||
|
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
# I can't find any standalone 2.X VAEs to test with!
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
|
||||||
|
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||||
|
"""Class for LoRA checkpoints."""
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat("lycoris")
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
checkpoint = self.checkpoint
|
||||||
|
token_vector_length = lora_token_vector_length(checkpoint)
|
||||||
|
|
||||||
|
if token_vector_length == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif token_vector_length == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif token_vector_length == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
"""Class for probing embeddings."""
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat.EmbeddingFile
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
checkpoint = self.checkpoint
|
||||||
|
if "string_to_token" in checkpoint:
|
||||||
|
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||||
|
elif "emb_params" in checkpoint:
|
||||||
|
token_dim = checkpoint["emb_params"].shape[-1]
|
||||||
|
else:
|
||||||
|
token_dim = list(checkpoint.values())[0].shape[0]
|
||||||
|
if token_dim == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif token_dim == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException("Could not determine base type")
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||||
|
"""Class for probing controlnets."""
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
checkpoint = self.checkpoint
|
||||||
|
for key_name in (
|
||||||
|
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
|
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
|
):
|
||||||
|
if key_name not in checkpoint:
|
||||||
|
continue
|
||||||
|
if checkpoint[key_name].shape[-1] == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif checkpoint[key_name].shape[-1] == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
raise InvalidModelConfigException("Unable to determine base type for {self.checkpoint_path}")
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
########################################################
|
||||||
|
# classes for probing folders
|
||||||
|
#######################################################
|
||||||
|
class FolderProbeBase(ProbeBase):
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat("diffusers")
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||||
|
unet_conf = json.load(file)
|
||||||
|
if unet_conf["cross_attention_dim"] == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif unet_conf["cross_attention_dim"] == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif unet_conf["cross_attention_dim"] == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
|
elif unet_conf["cross_attention_dim"] == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||||
|
|
||||||
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||||
|
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||||
|
scheduler_conf = json.load(file)
|
||||||
|
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||||
|
return SchedulerPredictionType.VPrediction
|
||||||
|
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||||
|
return SchedulerPredictionType.Epsilon
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
# This only works for pipelines! Any kind of
|
||||||
|
# exception results in our returning the
|
||||||
|
# "normal" variant type
|
||||||
|
try:
|
||||||
|
config_file = self.model_path / "unet" / "config.json"
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
conf = json.load(file)
|
||||||
|
|
||||||
|
in_channels = conf["in_channels"]
|
||||||
|
if in_channels == 9:
|
||||||
|
return ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
return ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
|
class VaeFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
if self._config_looks_like_sdxl():
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif self._name_looks_like_sdxl():
|
||||||
|
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||||
|
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
def _config_looks_like_sdxl(self) -> bool:
|
||||||
|
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||||
|
config_file = self.model_path / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
config = json.load(file)
|
||||||
|
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||||
|
|
||||||
|
def _name_looks_like_sdxl(self) -> bool:
|
||||||
|
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||||
|
|
||||||
|
def _guess_name(self) -> str:
|
||||||
|
name = self.model_path.name
|
||||||
|
if name == "vae":
|
||||||
|
name = self.model_path.parent.name
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat.EmbeddingFolder
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
path = self.model_path / "learned_embeds.bin"
|
||||||
|
if not path.exists():
|
||||||
|
raise InvalidModelConfigException(f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file")
|
||||||
|
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> ModelFormat:
|
||||||
|
return ModelFormat("onnx")
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
config_file = self.model_path / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
config = json.load(file)
|
||||||
|
# no obvious way to distinguish between sd2-base and sd2-768
|
||||||
|
dimension = config["cross_attention_dim"]
|
||||||
|
base_model = (
|
||||||
|
BaseModelType.StableDiffusion1
|
||||||
|
if dimension == 768
|
||||||
|
else (
|
||||||
|
BaseModelType.StableDiffusion2
|
||||||
|
if dimension == 1024
|
||||||
|
else BaseModelType.StableDiffusionXL
|
||||||
|
if dimension == 2048
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not base_model:
|
||||||
|
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
model_file = None
|
||||||
|
for suffix in ["safetensors", "bin"]:
|
||||||
|
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
|
||||||
|
if base_file.exists():
|
||||||
|
model_file = base_file
|
||||||
|
break
|
||||||
|
if not model_file:
|
||||||
|
raise InvalidModelConfigException("Unknown LoRA format encountered")
|
||||||
|
return LoRACheckpointProbe(model_file).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> IPAdapterModelFormat:
|
||||||
|
return IPAdapterModelFormat.InvokeAI.value
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
model_file = self.model_path / "ip_adapter.bin"
|
||||||
|
if not model_file.exists():
|
||||||
|
raise InvalidModelConfigException("Unknown IP-Adapter model format.")
|
||||||
|
|
||||||
|
state_dict = torch.load(model_file, map_location="cpu")
|
||||||
|
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||||
|
if cross_attention_dim == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif cross_attention_dim == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif cross_attention_dim == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
config_file = self.model_path / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
config = json.load(file)
|
||||||
|
|
||||||
|
adapter_type = config.get("adapter_type", None)
|
||||||
|
if adapter_type == "full_adapter_xl":
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif adapter_type == "full_adapter" or "light_adapter":
|
||||||
|
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
else:
|
||||||
|
raise InvalidModelConfigException(
|
||||||
|
f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############## register probe classes ######
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||||
|
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||||
|
|
||||||
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
@ -219,6 +219,8 @@ exclude = [
|
|||||||
# global mypy config
|
# global mypy config
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
ignore_missing_imports = true # ignores missing types in third-party libraries
|
ignore_missing_imports = true # ignores missing types in third-party libraries
|
||||||
|
strict = true
|
||||||
|
exclude = ["tests/*"]
|
||||||
|
|
||||||
# overrides for specific modules
|
# overrides for specific modules
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Probe model type")
|
parser = argparse.ArgumentParser(description="Probe model type")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
89
tests/app/services/model_install/test_model_install.py
Normal file
89
tests/app/services/model_install/test_model_install.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Test the model installer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
from invokeai.backend.model_manager.config import ModelType, BaseModelType
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceBase
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_file(datadir: Path) -> Path:
|
||||||
|
return datadir / "test_embedding.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||||
|
return InvokeAIAppConfig(
|
||||||
|
root=datadir / "root",
|
||||||
|
models_dir=datadir / "root/models",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
|
database = SqliteDatabase(app_config, InvokeAILogger.get_logger(config=app_config))
|
||||||
|
store: ModelRecordServiceBase = ModelRecordServiceSQL(database)
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def installer(app_config: InvokeAIAppConfig,
|
||||||
|
store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||||
|
return ModelInstallService(app_config=app_config,
|
||||||
|
record_store=store
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
|
assert len(matches) == 0
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
assert key is not None
|
||||||
|
assert len(key) == 32
|
||||||
|
|
||||||
|
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record is not None
|
||||||
|
assert model_record.name == "test_embedding"
|
||||||
|
assert model_record.type == ModelType.TextualInversion
|
||||||
|
assert Path(model_record.path) == test_file
|
||||||
|
assert model_record.base == BaseModelType('sd-1')
|
||||||
|
assert model_record.description is not None
|
||||||
|
assert model_record.source is not None
|
||||||
|
assert Path(model_record.source) == test_file
|
||||||
|
|
||||||
|
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
key = None
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||||
|
assert key is None
|
||||||
|
|
||||||
|
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(test_file,
|
||||||
|
{
|
||||||
|
"name": "banana_sushi",
|
||||||
|
"source": "fake/repo_id",
|
||||||
|
"current_hash": "New Hash"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record.name == "banana_sushi"
|
||||||
|
assert model_record.source == "fake/repo_id"
|
||||||
|
assert model_record.current_hash == "New Hash"
|
||||||
|
|
||||||
|
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.install_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||||
|
assert model_record.source == test_file.as_posix()
|
@ -0,0 +1,79 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
personalization_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
|
||||||
|
params:
|
||||||
|
placeholder_strings: ["*"]
|
||||||
|
initializer_words: ['sculpture']
|
||||||
|
per_image_tokens: false
|
||||||
|
num_vectors_per_token: 1
|
||||||
|
progressive_words: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
|
@ -0,0 +1 @@
|
|||||||
|
Dummy file to establish git path.
|
Binary file not shown.
Loading…
Reference in New Issue
Block a user