mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
all backend features in place; config scanning is failing on controlnet
This commit is contained in:
parent
80bc9be3ab
commit
19baea1883
@ -1,6 +1,5 @@
|
|||||||
"""
|
"""Init file for InvokeAI configure package."""
|
||||||
Init file for InvokeAI configure package
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .config_base import PagingArgumentParser # noqa F401
|
from .config_default import InvokeAIAppConfig, get_invokeai_config
|
||||||
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
|
||||||
|
__all__ = ['InvokeAIAppConfig', 'get_invokeai_config']
|
||||||
|
@ -323,21 +323,23 @@ class EventServiceBase:
|
|||||||
"""
|
"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_install_started",
|
event_name="model_install_started",
|
||||||
payload={"source": source},
|
payload={
|
||||||
|
"source": source
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_install_completed(self, source: str, dest: str) -> None:
|
def emit_model_install_completed(self, source: str, key: str) -> None:
|
||||||
"""
|
"""
|
||||||
Emitted when an install job is completed successfully.
|
Emitted when an install job is completed successfully.
|
||||||
|
|
||||||
:param source: Source of the model; local path, repo_id or url
|
:param source: Source of the model; local path, repo_id or url
|
||||||
:param dest: Destination of the model files; always a local path
|
:param key: Model config record key
|
||||||
"""
|
"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_install_completed",
|
event_name="model_install_completed",
|
||||||
payload={
|
payload={
|
||||||
"source": source,
|
"source": source,
|
||||||
"dest": dest,
|
"key": key,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,2 +1,6 @@
|
|||||||
from .model_install_base import ModelInstallServiceBase # noqa F401
|
"""Initialization file for model install service package."""
|
||||||
from .model_install_default import ModelInstallService # noqa F401
|
|
||||||
|
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
|
||||||
|
from .model_install_default import ModelInstallService
|
||||||
|
|
||||||
|
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
|
||||||
|
@ -21,12 +21,21 @@ class InstallStatus(str, Enum):
|
|||||||
ERROR = "error" # terminated with an error message
|
ERROR = "error" # terminated with an error message
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownInstallJobException(Exception):
|
||||||
|
"""Raised when the status of an unknown job is requested."""
|
||||||
|
|
||||||
|
|
||||||
|
ModelSource = Union[str, Path, AnyHttpUrl]
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallJob(BaseModel):
|
class ModelInstallJob(BaseModel):
|
||||||
"""Object that tracks the current status of an install request."""
|
"""Object that tracks the current status of an install request."""
|
||||||
|
|
||||||
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
|
|
||||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||||
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Configuration metadata to apply to model before installing it")
|
||||||
|
inplace: bool = Field(default=False, description="Leave model in its current location; otherwise install under models directory")
|
||||||
|
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||||
|
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||||
|
key: Optional[str] = Field(default=None, description="After model is installed, this is its config record key")
|
||||||
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
|
error_type: str = Field(default="", description="Class name of the exception that led to status==ERROR")
|
||||||
error: str = Field(default="", description="Error traceback") # noqa #501
|
error: str = Field(default="", description="Error traceback") # noqa #501
|
||||||
|
|
||||||
@ -65,6 +74,11 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def record_store(self) -> ModelRecordServiceBase:
|
def record_store(self) -> ModelRecordServiceBase:
|
||||||
"""Return the ModelRecoreService object associated with the installer."""
|
"""Return the ModelRecoreService object associated with the installer."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def event_bus(self) -> Optional[EventServiceBase]:
|
||||||
|
"""Return the event service base object associated with the installer."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
@ -86,16 +100,12 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""Remove model with indicated key from the database."""
|
"""Remove model with indicated key from the database."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None: # noqa D102
|
||||||
"""Remove model with indicated key from the database and delete weight files from disk."""
|
"""Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def conditionally_delete(self, key: str) -> None:
|
def unconditionally_delete(self, key: str) -> None:
|
||||||
"""
|
"""Remove model with indicated key from the database and unconditionally delete weight files from disk."""
|
||||||
Remove model with indicated key from the database
|
|
||||||
and conditeionally delete weight files from disk
|
|
||||||
if they reside within InvokeAI's models directory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def install_path(
|
def install_path(
|
||||||
@ -121,7 +131,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
"""Install the indicated model.
|
"""Install the indicated model.
|
||||||
@ -168,6 +178,18 @@ class ModelInstallServiceBase(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_job(self, source: ModelSource) -> ModelInstallJob:
|
||||||
|
"""Return the ModelInstallJob corresponding to the provided source."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||||
|
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prune_jobs(self) -> None:
|
||||||
|
"""Prune all completed and errored jobs."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
|
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]:
|
||||||
"""
|
"""
|
||||||
@ -194,4 +216,4 @@ class ModelInstallServiceBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sync_to_config(self) -> None:
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in memory."""
|
"""Synchronize models on disk to those in the model record database."""
|
||||||
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
from queue import Queue
|
from queue import Queue
|
||||||
from random import randbytes
|
from random import randbytes
|
||||||
from shutil import move, rmtree
|
from shutil import move, rmtree
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Set, Optional, Union
|
||||||
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
@ -18,14 +18,16 @@ from invokeai.backend.model_manager.config import (
|
|||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.config import ModelType, BaseModelType
|
||||||
from invokeai.backend.model_manager.hash import FastModelHash
|
from invokeai.backend.model_manager.hash import FastModelHash
|
||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||||
|
|
||||||
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
|
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
|
||||||
|
|
||||||
# marker that the queue is done and that thread should exit
|
# marker that the queue is done and that thread should exit
|
||||||
STOP_JOB = ModelInstallJob(source="stop")
|
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallService(ModelInstallServiceBase):
|
class ModelInstallService(ModelInstallServiceBase):
|
||||||
@ -35,8 +37,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
_record_store: ModelRecordServiceBase
|
_record_store: ModelRecordServiceBase
|
||||||
_event_bus: Optional[EventServiceBase] = None
|
_event_bus: Optional[EventServiceBase] = None
|
||||||
_install_queue: Queue[ModelInstallJob]
|
_install_queue: Queue[ModelInstallJob]
|
||||||
_install_jobs: Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]
|
_install_jobs: Dict[ModelSource, ModelInstallJob]
|
||||||
_logger: InvokeAILogger
|
_logger: InvokeAILogger
|
||||||
|
_cached_model_paths: Set[Path]
|
||||||
|
_models_installed: Set[str]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
@ -52,9 +56,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"""
|
"""
|
||||||
self._app_config = app_config
|
self._app_config = app_config
|
||||||
self._record_store = record_store
|
self._record_store = record_store
|
||||||
self._install_queue = Queue()
|
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||||
|
self._install_jobs = {}
|
||||||
|
self._install_queue = Queue()
|
||||||
|
self._cached_model_paths = set()
|
||||||
|
self._models_installed = set()
|
||||||
self._start_installer_thread()
|
self._start_installer_thread()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -65,6 +72,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def record_store(self) -> ModelRecordServiceBase: # noqa D102
|
def record_store(self) -> ModelRecordServiceBase: # noqa D102
|
||||||
return self._record_store
|
return self._record_store
|
||||||
|
|
||||||
|
@property
|
||||||
|
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||||
|
return self._event_bus
|
||||||
|
|
||||||
|
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||||
|
return self._install_jobs
|
||||||
|
|
||||||
def _start_installer_thread(self) -> None:
|
def _start_installer_thread(self) -> None:
|
||||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||||
|
|
||||||
@ -74,14 +88,20 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job = self._install_queue.get()
|
job = self._install_queue.get()
|
||||||
if job == STOP_JOB:
|
if job == STOP_JOB:
|
||||||
done = True
|
done = True
|
||||||
elif job.status == InstallStatus.WAITING:
|
continue
|
||||||
|
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
self.register_path(job.local_path)
|
if job.inplace:
|
||||||
|
job.key = self.register_path(job.local_path, job.metadata)
|
||||||
|
else:
|
||||||
|
job.key = self.install_path(job.local_path, job.metadata)
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
|
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
|
||||||
self._signal_job_errored(job, excp)
|
self._signal_job_errored(job, excp)
|
||||||
|
finally:
|
||||||
|
self._install_queue.task_done()
|
||||||
|
|
||||||
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||||
job.status = InstallStatus.RUNNING
|
job.status = InstallStatus.RUNNING
|
||||||
@ -92,7 +112,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.status = InstallStatus.COMPLETED
|
job.status = InstallStatus.COMPLETED
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
self._event_bus.emit_model_install_completed(str(job.source), job.local_path.as_posix())
|
self._event_bus.emit_model_install_completed(str(job.source), job.key)
|
||||||
|
|
||||||
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||||
job.set_error(excp)
|
job.set_error(excp)
|
||||||
@ -136,29 +156,165 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def import_model(
|
def import_model(
|
||||||
self,
|
self,
|
||||||
source: Union[str, Path, AnyHttpUrl],
|
source: ModelSource,
|
||||||
inplace: bool = True,
|
inplace: bool = True,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob: # noqa D102
|
) -> ModelInstallJob: # noqa D102
|
||||||
|
# Clean up a common source of error. Doesn't work with Paths.
|
||||||
|
if isinstance(source, str):
|
||||||
|
source = source.strip()
|
||||||
|
|
||||||
|
if not metadata:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Installing a local path
|
||||||
|
if isinstance(source, (str, Path)) and Path(source).exists(): # a path that is already on disk
|
||||||
|
job = ModelInstallJob(metadata=metadata,
|
||||||
|
source=source,
|
||||||
|
inplace=inplace,
|
||||||
|
local_path=Path(source),
|
||||||
|
)
|
||||||
|
self._install_jobs[source] = job
|
||||||
|
self._install_queue.put(job)
|
||||||
|
return job
|
||||||
|
|
||||||
|
else: # waiting for download queue implementation
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], ModelInstallJob]: # noqa D102
|
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
||||||
|
try:
|
||||||
|
return self._install_jobs[source]
|
||||||
|
except KeyError:
|
||||||
|
raise UnknownInstallJobException(f'{source}: unknown install job')
|
||||||
|
|
||||||
|
def wait_for_installs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||||
self._install_queue.join()
|
self._install_queue.join()
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def prune_jobs(self) -> None:
|
||||||
raise NotImplementedError
|
"""Prune all completed and errored jobs."""
|
||||||
|
finished_jobs = [source for source in self._install_jobs
|
||||||
|
if self._install_jobs[source].status in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||||
|
]
|
||||||
|
for source in finished_jobs:
|
||||||
|
del self._install_jobs[source]
|
||||||
|
|
||||||
def sync_to_config(self) -> None: # noqa D102
|
def sync_to_config(self) -> None:
|
||||||
raise NotImplementedError
|
"""Synchronize models on disk to those in the config record store database."""
|
||||||
|
self._scan_models_directory()
|
||||||
|
if autoimport := self._app_config.autoimport_dir:
|
||||||
|
self._logger.info("Scanning autoimport directory for new models")
|
||||||
|
self.scan_directory(self._app_config.root_path / autoimport)
|
||||||
|
|
||||||
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
|
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||||
|
callback = self._scan_install if install else self._scan_register
|
||||||
|
search = ModelSearch(on_model_found=callback)
|
||||||
|
self._models_installed: Set[str] = set()
|
||||||
|
search.search(scan_dir)
|
||||||
|
return list(self._models_installed)
|
||||||
|
|
||||||
|
def _scan_models_directory(self) -> None:
|
||||||
|
"""
|
||||||
|
Scan the models directory for new and missing models.
|
||||||
|
|
||||||
|
New models will be added to the storage backend. Missing models
|
||||||
|
will be deleted.
|
||||||
|
"""
|
||||||
|
defunct_models = set()
|
||||||
|
installed = set()
|
||||||
|
|
||||||
|
with Chdir(self._app_config.models_path):
|
||||||
|
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||||
|
for model_config in self.record_store.all_models():
|
||||||
|
path = Path(model_config.path)
|
||||||
|
if not path.exists():
|
||||||
|
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||||
|
defunct_models.add(model_config.key)
|
||||||
|
for key in defunct_models:
|
||||||
|
self.unregister(key)
|
||||||
|
|
||||||
|
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||||
|
for cur_base_model in BaseModelType:
|
||||||
|
for cur_model_type in ModelType:
|
||||||
|
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||||
|
installed.update(self.scan_directory(models_dir))
|
||||||
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
|
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Move model into the location indicated by its basetype, type and name.
|
||||||
|
|
||||||
|
Call this after updating a model's attributes in order to move
|
||||||
|
the model's path into the location indicated by its basetype, type and
|
||||||
|
name. Applies only to models whose paths are within the root `models_dir`
|
||||||
|
directory.
|
||||||
|
|
||||||
|
May raise an UnknownModelException.
|
||||||
|
"""
|
||||||
|
model = self.record_store.get_model(key)
|
||||||
|
old_path = Path(model.path)
|
||||||
|
models_dir = self.app_config.models_path
|
||||||
|
|
||||||
|
if not old_path.is_relative_to(models_dir):
|
||||||
|
return model
|
||||||
|
|
||||||
|
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||||
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
|
new_path = self._move_model(old_path, new_path)
|
||||||
|
new_hash = FastModelHash.hash(new_path)
|
||||||
|
model.path = new_path.relative_to(models_dir).as_posix()
|
||||||
|
if model.current_hash != new_hash:
|
||||||
|
assert (
|
||||||
|
ignore_hash_change
|
||||||
|
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||||
|
model.current_hash = new_hash
|
||||||
|
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||||
|
self.record_store.update_model(key, model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_register(self, model: Path) -> bool:
|
||||||
|
if model in self._cached_model_paths:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
id = self.register_path(model)
|
||||||
|
self._sync_model_path(id) # possibly move it to right place in `models`
|
||||||
|
self._logger.info(f"Registered {model.name} with id {id}")
|
||||||
|
self._models_installed.add(id)
|
||||||
|
except DuplicateModelException:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_install(self, model: Path) -> bool:
|
||||||
|
if model in self._cached_model_paths:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
id = self.install_path(model)
|
||||||
|
self._logger.info(f"Installed {model} with id {id}")
|
||||||
|
self._models_installed.add(id)
|
||||||
|
except DuplicateModelException:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
def unregister(self, key: str) -> None: # noqa D102
|
def unregister(self, key: str) -> None: # noqa D102
|
||||||
self.record_store.del_model(key)
|
self.record_store.del_model(key)
|
||||||
|
|
||||||
def delete(self, key: str) -> None: # noqa D102
|
def delete(self, key: str) -> None: # noqa D102
|
||||||
|
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||||
|
model = self.record_store.get_model(key)
|
||||||
|
models_dir = self.app_config.models_path
|
||||||
|
model_path = models_dir / model.path
|
||||||
|
if model_path.is_relative_to(models_dir):
|
||||||
|
self.unconditionally_delete(key)
|
||||||
|
else:
|
||||||
|
self.unregister(key)
|
||||||
|
|
||||||
|
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
path = self.app_config.models_path / model.path
|
path = self.app_config.models_path / model.path
|
||||||
if path.is_dir():
|
if path.is_dir():
|
||||||
@ -167,16 +323,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
def conditionally_delete(self, key: str) -> None: # noqa D102
|
|
||||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
|
||||||
model = self.record_store.get_model(key)
|
|
||||||
models_dir = self.app_config.models_path
|
|
||||||
model_path = models_dir / model.path
|
|
||||||
if model_path.is_relative_to(models_dir):
|
|
||||||
self.delete(key)
|
|
||||||
else:
|
|
||||||
self.unregister(key)
|
|
||||||
|
|
||||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||||
if old_path == new_path:
|
if old_path == new_path:
|
||||||
return old_path
|
return old_path
|
||||||
|
@ -12,3 +12,6 @@ from .devices import ( # noqa: F401
|
|||||||
torch_dtype,
|
torch_dtype,
|
||||||
)
|
)
|
||||||
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
|
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
|
||||||
|
from .logging import InvokeAILogger
|
||||||
|
|
||||||
|
__all__ = ['Chdir', 'InvokeAILogger', 'choose_precision', 'choose_torch_device']
|
||||||
|
@ -2,15 +2,26 @@
|
|||||||
Test the model installer
|
Test the model installer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import ValidationError
|
from typing import List, Any, Dict
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.backend.model_manager.config import ModelType, BaseModelType
|
import pytest
|
||||||
|
from pydantic import ValidationError, BaseModel
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.model_install import (
|
||||||
|
ModelInstallService,
|
||||||
|
ModelInstallServiceBase,
|
||||||
|
InstallStatus,
|
||||||
|
ModelInstallJob,
|
||||||
|
UnknownInstallJobException,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_file(datadir: Path) -> Path:
|
def test_file(datadir: Path) -> Path:
|
||||||
@ -36,10 +47,34 @@ def store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
|||||||
def installer(app_config: InvokeAIAppConfig,
|
def installer(app_config: InvokeAIAppConfig,
|
||||||
store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||||
return ModelInstallService(app_config=app_config,
|
return ModelInstallService(app_config=app_config,
|
||||||
record_store=store
|
record_store=store,
|
||||||
|
event_bus=DummyEventService(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEvent(BaseModel):
|
||||||
|
"""Dummy Event to use with Dummy Event service."""
|
||||||
|
|
||||||
|
event_name: str
|
||||||
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEventService(EventServiceBase):
|
||||||
|
"""Dummy event service for testing."""
|
||||||
|
|
||||||
|
events: List[DummyEvent]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.events = []
|
||||||
|
|
||||||
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
|
"""Dispatch an event by appending it to self.events."""
|
||||||
|
self.events.append(
|
||||||
|
DummyEvent(event_name=payload['event'],
|
||||||
|
payload=payload['data'])
|
||||||
|
)
|
||||||
|
|
||||||
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||||
store = installer.record_store
|
store = installer.record_store
|
||||||
matches = store.search_by_attr(model_name="test_embedding")
|
matches = store.search_by_attr(model_name="test_embedding")
|
||||||
@ -87,3 +122,69 @@ def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config
|
|||||||
model_record = store.get_model(key)
|
model_record = store.get_model(key)
|
||||||
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||||
assert model_record.source == test_file.as_posix()
|
assert model_record.source == test_file.as_posix()
|
||||||
|
|
||||||
|
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||||
|
"""Note: may want to break this down into several smaller unit tests."""
|
||||||
|
source = test_file
|
||||||
|
description = "Test of metadata assignment"
|
||||||
|
job = installer.import_model(source, inplace=False, metadata={"description": description})
|
||||||
|
assert job is not None
|
||||||
|
assert isinstance(job, ModelInstallJob)
|
||||||
|
|
||||||
|
# See if job is registered properly
|
||||||
|
assert installer.get_job(source) == job
|
||||||
|
|
||||||
|
# test that the job object tracked installation correctly
|
||||||
|
jobs = installer.wait_for_installs()
|
||||||
|
assert jobs[source] is not None
|
||||||
|
assert jobs[source] == job
|
||||||
|
assert jobs[source].status == InstallStatus.COMPLETED
|
||||||
|
|
||||||
|
# test that the expected events were issued
|
||||||
|
bus = installer.event_bus
|
||||||
|
assert bus is not None # sigh - ruff is a stickler for type checking
|
||||||
|
assert isinstance(bus, DummyEventService)
|
||||||
|
assert len(bus.events) == 2
|
||||||
|
event_names = [x.event_name for x in bus.events]
|
||||||
|
assert "model_install_started" in event_names
|
||||||
|
assert "model_install_completed" in event_names
|
||||||
|
assert bus.events[0].payload["source"] == source.as_posix()
|
||||||
|
assert bus.events[1].payload["source"] == source.as_posix()
|
||||||
|
key = bus.events[1].payload["key"]
|
||||||
|
assert key is not None
|
||||||
|
|
||||||
|
# see if the thing actually got installed at the expected location
|
||||||
|
model_record = installer.record_store.get_model(key)
|
||||||
|
assert model_record is not None
|
||||||
|
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
|
||||||
|
# see if metadata was properly passed through
|
||||||
|
assert model_record.description == description
|
||||||
|
|
||||||
|
# see if prune works properly
|
||||||
|
installer.prune_jobs()
|
||||||
|
with pytest.raises(UnknownInstallJobException):
|
||||||
|
assert installer.get_job(source)
|
||||||
|
|
||||||
|
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.install_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
assert not test_file.exists() # original should not still be there after installation
|
||||||
|
installer.delete(key)
|
||||||
|
assert not Path(app_config.models_dir / model_record.path).exists() # but installed copy should not!
|
||||||
|
with pytest.raises(UnknownModelException):
|
||||||
|
store.get_model(key)
|
||||||
|
|
||||||
|
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||||
|
store = installer.record_store
|
||||||
|
key = installer.register_path(test_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
assert test_file.exists() # original should still be there after installation
|
||||||
|
installer.delete(key)
|
||||||
|
assert Path(app_config.models_dir / model_record.path).exists()
|
||||||
|
with pytest.raises(UnknownModelException):
|
||||||
|
store.get_model(key)
|
||||||
|
Loading…
Reference in New Issue
Block a user