diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index bfb982b4c0..3682ec0c6b 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -344,7 +344,8 @@ class EventServiceBase: def emit_model_install_error(self, source:str, - exception: Exception, + error_type: str, + error: str, ) -> None: """ Emitted when an install job encounters an exception. @@ -352,13 +353,6 @@ class EventServiceBase: :param source: Source of the model :param exception: The exception that raised the error """ - - # Revisit: - # it makes more sense to receive an exception and break it out - # into error_type and error here, rather than at the caller's side - error_type = exception.__class__.__name__, - error = traceback.format_exc(), - self.__emit_queue_event( event_name="model_install_error", payload={ diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index fcf3ab9d17..0b5e5cb650 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -1,11 +1,13 @@ +import traceback + from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List from pydantic import BaseModel, Field from pydantic.networks import AnyHttpUrl -from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events import EventServiceBase @@ -14,13 +16,25 @@ class InstallStatus(str, Enum): """State of an install job running in the background.""" WAITING = "waiting" # waiting to be dequeued + RUNNING = "running" # being processed COMPLETED = "completed" # finished running ERROR = "error" # terminated with an error message -class ModelInstallStatus(BaseModel): - status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") # noqa #501 - error: Optional[Exception] = Field(default=None, description="Exception that led to status==ERROR") # noqa #501 +class ModelInstallJob(BaseModel): + """Object that tracks the current status of an install request.""" + + source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model") + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") + local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model") + error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR") + error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501 + + def set_error(self, e: Exception) -> None: + """Record the error and traceback from an exception.""" + self.error_type = e.__class__.__name__ + self.error = traceback.format_exc() + self.status = InstallStatus.ERROR class ModelInstallServiceBase(ABC): @@ -42,18 +56,6 @@ class ModelInstallServiceBase(ABC): """ pass - @property - @abstractmethod - def config(self) -> InvokeAIAppConfig: - """Return the app_config used by the installer.""" - pass - - @property - @abstractmethod - def store(self) -> ModelRecordServiceBase: - """Return the storage backend used by the installer.""" - pass - @abstractmethod def register_path( self, @@ -108,7 +110,7 @@ class ModelInstallServiceBase(ABC): subfolder: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, access_token: Optional[str] = None, - ) -> ModelInstallStatus: + ) -> ModelInstallJob: """Install the indicated model. :param source: Either a URL or a HuggingFace repo_id. @@ -142,7 +144,7 @@ class ModelInstallServiceBase(ABC): The `inplace` flag does not affect the behavior of downloaded models, which are always moved into the `models` directory. - The call returns a ModelInstallStatus object which can be + The call returns a ModelInstallJob object which can be polled to learn the current status and/or error message. Variants recognized by HuggingFace currently are: @@ -195,34 +197,3 @@ class ModelInstallServiceBase(ABC): """ pass - # The following are internal methods - @abstractmethod - def _create_name(self, model_path: Union[Path, str]) -> str: - """ - Creates a default name for the model. - - :param model_path: Path to the model on disk. - :return str: Model name - """ - pass - - @abstractmethod - def _create_description(self, model_path: Union[Path, str]) -> str: - """ - Creates a default description for the model. - - :param model_path: Path to the model on disk. - :return str: Model description - """ - pass - - @abstractmethod - def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str: - """ - Creates a unique ID for the model for use with the model records module. # noqa E501 - - :param model_path: Path to the model on disk. - :param name: (optional) non-default name for the model - :return str: Unique ID - """ - pass diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 8ed61b8b36..11ca0feaf2 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -1,21 +1,77 @@ +"""Model installation class.""" + +import threading + from pathlib import Path -from typing import Dict, Optional, Union -from pydantic import BaseModel, Field +from typing import Dict, Optional, Union, List from queue import Queue from pydantic.networks import AnyHttpUrl -from .model_install_base import InstallStatus, ModelInstallStatus, ModelInstallServiceBase +from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase + +from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe +from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException + from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.events import EventServiceBase +from invokeai.backend.util.logging import InvokeAILogger -class ModelInstallService(ModelInstallBase, BaseModel): +# marker that the queue is done and that thread should exit +STOP_JOB = ModelInstallJob(source="stop") + + +class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" config: InvokeAIAppConfig store: ModelRecordServiceBase - event_bus: Optional[EventServiceBase] = Field(default=None) - install_queue: Queue = Field(default_factory=Queue) + _event_bus: Optional[EventServiceBase] = None + _install_queue: Queue + + def __init__(self, + config: InvokeAIAppConfig, + store: ModelRecordServiceBase, + install_queue: Optional[Queue] = None, + event_bus: Optional[EventServiceBase] = None + ): + self.config = config + self.store = store + self._install_queue = install_queue or Queue() + self._event_bus = event_bus + self._start_installer_thread() + + def _start_installer_thread(self): + threading.Thread(target=self._install_next_item, daemon=True).start() + + def _install_next_item(self): + done = False + while not done: + job = self._install_queue.get() + if job == STOP_JOB: + done = True + elif job.status == InstallStatus.WAITING: + try: + self._signal_job_running(job) + self.register_path(job.path) + self._signal_job_completed(job) + except (OSError, DuplicateModelException, InvalidModelConfigException) as e: + self._signal_job_errored(job, e) + + def _signal_job_running(self, job: ModelInstallJob): + job.status = InstallStatus.RUNNING + if self._event_bus: + self._event_bus.emit_model_install_started(job.source) + + def _signal_job_completed(self, job: ModelInstallJob): + job.status = InstallStatus.COMPLETED + if self._event_bus: + self._event_bus.emit_model_install_completed(job.source, job.dest) + + def _signal_job_errored(self, job: ModelInstallJob, e: Exception): + job.set_error(e) + if self._event_bus: + self._event_bus.emit_model_install_error(job.source, job.error_type, job.error) def register_path( self, @@ -45,7 +101,7 @@ class ModelInstallService(ModelInstallBase, BaseModel): subfolder: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, access_token: Optional[str] = None, - ) -> ModelInstallStatus: + ) -> ModelInstallJob: raise NotImplementedError def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: @@ -68,10 +124,10 @@ class ModelInstallService(ModelInstallBase, BaseModel): else: return model_path.name - def _create_description(self, model_path: Union[Path, str]) -> str: - info: ModelProbeInfo = ModelProbe.probe(Path(model_path)) - name: str = name or self._create_name(model_path) - return f"a {info.model_type} model based on {info.base_type}" + def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: + info = info or ModelProbe.probe(Path(model_path)) + name: str = self._create_name(model_path) + return f"a {info.model_type} model {name} based on {info.base_type}" def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str: name: str = name or self._create_name(model_path) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 457e6b0823..d5fcbf0cd2 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -29,6 +29,9 @@ from typing_extensions import Annotated class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" +class DuplicateModelException(Exception): + """Exception for when a duplicate model is detected during installation.""" + class BaseModelType(str, Enum): """Base model type."""