start implementation of installer

This commit is contained in:
Lincoln Stein 2023-11-20 23:02:30 -05:00
parent 6c56233edc
commit 9ea3126118
4 changed files with 93 additions and 69 deletions

View File

@ -344,7 +344,8 @@ class EventServiceBase:
def emit_model_install_error(self,
source:str,
exception: Exception,
error_type: str,
error: str,
) -> None:
"""
Emitted when an install job encounters an exception.
@ -352,13 +353,6 @@ class EventServiceBase:
:param source: Source of the model
:param exception: The exception that raised the error
"""
# Revisit:
# it makes more sense to receive an exception and break it out
# into error_type and error here, rather than at the caller's side
error_type = exception.__class__.__name__,
error = traceback.format_exc(),
self.__emit_queue_event(
event_name="model_install_error",
payload={

View File

@ -1,7 +1,9 @@
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, List
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
@ -14,13 +16,25 @@ class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
class ModelInstallStatus(BaseModel):
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") # noqa #501
error: Optional[Exception] = Field(default=None, description="Exception that led to status==ERROR") # noqa #501
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self.error_type = e.__class__.__name__
self.error = traceback.format_exc()
self.status = InstallStatus.ERROR
class ModelInstallServiceBase(ABC):
@ -42,18 +56,6 @@ class ModelInstallServiceBase(ABC):
"""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the storage backend used by the installer."""
pass
@abstractmethod
def register_path(
self,
@ -108,7 +110,7 @@ class ModelInstallServiceBase(ABC):
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
access_token: Optional[str] = None,
) -> ModelInstallStatus:
) -> ModelInstallJob:
"""Install the indicated model.
:param source: Either a URL or a HuggingFace repo_id.
@ -142,7 +144,7 @@ class ModelInstallServiceBase(ABC):
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
The call returns a ModelInstallStatus object which can be
The call returns a ModelInstallJob object which can be
polled to learn the current status and/or error message.
Variants recognized by HuggingFace currently are:
@ -195,34 +197,3 @@ class ModelInstallServiceBase(ABC):
"""
pass
# The following are internal methods
@abstractmethod
def _create_name(self, model_path: Union[Path, str]) -> str:
"""
Creates a default name for the model.
:param model_path: Path to the model on disk.
:return str: Model name
"""
pass
@abstractmethod
def _create_description(self, model_path: Union[Path, str]) -> str:
"""
Creates a default description for the model.
:param model_path: Path to the model on disk.
:return str: Model description
"""
pass
@abstractmethod
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
"""
Creates a unique ID for the model for use with the model records module. # noqa E501
:param model_path: Path to the model on disk.
:param name: (optional) non-default name for the model
:return str: Unique ID
"""
pass

View File

@ -1,21 +1,77 @@
"""Model installation class."""
import threading
from pathlib import Path
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
from typing import Dict, Optional, Union, List
from queue import Queue
from pydantic.networks import AnyHttpUrl
from .model_install_base import InstallStatus, ModelInstallStatus, ModelInstallServiceBase
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.backend.util.logging import InvokeAILogger
class ModelInstallService(ModelInstallBase, BaseModel):
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop")
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
config: InvokeAIAppConfig
store: ModelRecordServiceBase
event_bus: Optional[EventServiceBase] = Field(default=None)
install_queue: Queue = Field(default_factory=Queue)
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue
def __init__(self,
config: InvokeAIAppConfig,
store: ModelRecordServiceBase,
install_queue: Optional[Queue] = None,
event_bus: Optional[EventServiceBase] = None
):
self.config = config
self.store = store
self._install_queue = install_queue or Queue()
self._event_bus = event_bus
self._start_installer_thread()
def _start_installer_thread(self):
threading.Thread(target=self._install_next_item, daemon=True).start()
def _install_next_item(self):
done = False
while not done:
job = self._install_queue.get()
if job == STOP_JOB:
done = True
elif job.status == InstallStatus.WAITING:
try:
self._signal_job_running(job)
self.register_path(job.path)
self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
self._signal_job_errored(job, e)
def _signal_job_running(self, job: ModelInstallJob):
job.status = InstallStatus.RUNNING
if self._event_bus:
self._event_bus.emit_model_install_started(job.source)
def _signal_job_completed(self, job: ModelInstallJob):
job.status = InstallStatus.COMPLETED
if self._event_bus:
self._event_bus.emit_model_install_completed(job.source, job.dest)
def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
job.set_error(e)
if self._event_bus:
self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
def register_path(
self,
@ -45,7 +101,7 @@ class ModelInstallService(ModelInstallBase, BaseModel):
subfolder: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
access_token: Optional[str] = None,
) -> ModelInstallStatus:
) -> ModelInstallJob:
raise NotImplementedError
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
@ -68,10 +124,10 @@ class ModelInstallService(ModelInstallBase, BaseModel):
else:
return model_path.name
def _create_description(self, model_path: Union[Path, str]) -> str:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
name: str = name or self._create_name(model_path)
return f"a {info.model_type} model based on {info.base_type}"
def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
info = info or ModelProbe.probe(Path(model_path))
name: str = self._create_name(model_path)
return f"a {info.model_type} model {name} based on {info.base_type}"
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
name: str = name or self._create_name(model_path)

View File

@ -29,6 +29,9 @@ from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class DuplicateModelException(Exception):
"""Exception for when a duplicate model is detected during installation."""
class BaseModelType(str, Enum):
"""Base model type."""