mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
start implementation of installer
This commit is contained in:
parent
6c56233edc
commit
9ea3126118
@ -344,7 +344,8 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_install_error(self,
|
||||
source:str,
|
||||
exception: Exception,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emitted when an install job encounters an exception.
|
||||
@ -352,13 +353,6 @@ class EventServiceBase:
|
||||
:param source: Source of the model
|
||||
:param exception: The exception that raised the error
|
||||
"""
|
||||
|
||||
# Revisit:
|
||||
# it makes more sense to receive an exception and break it out
|
||||
# into error_type and error here, rather than at the caller's side
|
||||
error_type = exception.__class__.__name__,
|
||||
error = traceback.format_exc(),
|
||||
|
||||
self.__emit_queue_event(
|
||||
event_name="model_install_error",
|
||||
payload={
|
||||
|
@ -1,7 +1,9 @@
|
||||
import traceback
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@ -14,13 +16,25 @@ class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallStatus(BaseModel):
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") # noqa #501
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that led to status==ERROR") # noqa #501
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
|
||||
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
|
||||
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self.error_type = e.__class__.__name__
|
||||
self.error = traceback.format_exc()
|
||||
self.status = InstallStatus.ERROR
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@ -42,18 +56,6 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(
|
||||
self,
|
||||
@ -108,7 +110,7 @@ class ModelInstallServiceBase(ABC):
|
||||
subfolder: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallStatus:
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
@ -142,7 +144,7 @@ class ModelInstallServiceBase(ABC):
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
The call returns a ModelInstallStatus object which can be
|
||||
The call returns a ModelInstallJob object which can be
|
||||
polled to learn the current status and/or error message.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
@ -195,34 +197,3 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# The following are internal methods
|
||||
@abstractmethod
|
||||
def _create_name(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Creates a default name for the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: Model name
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _create_description(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Creates a default description for the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: Model description
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
||||
"""
|
||||
Creates a unique ID for the model for use with the model records module. # noqa E501
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:param name: (optional) non-default name for the model
|
||||
:return str: Unique ID
|
||||
"""
|
||||
pass
|
||||
|
@ -1,21 +1,77 @@
|
||||
"""Model installation class."""
|
||||
|
||||
import threading
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Optional, Union, List
|
||||
from queue import Queue
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from .model_install_base import InstallStatus, ModelInstallStatus, ModelInstallServiceBase
|
||||
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
|
||||
|
||||
from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
|
||||
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
class ModelInstallService(ModelInstallBase, BaseModel):
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(source="stop")
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
store: ModelRecordServiceBase
|
||||
event_bus: Optional[EventServiceBase] = Field(default=None)
|
||||
install_queue: Queue = Field(default_factory=Queue)
|
||||
_event_bus: Optional[EventServiceBase] = None
|
||||
_install_queue: Queue
|
||||
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: ModelRecordServiceBase,
|
||||
install_queue: Optional[Queue] = None,
|
||||
event_bus: Optional[EventServiceBase] = None
|
||||
):
|
||||
self.config = config
|
||||
self.store = store
|
||||
self._install_queue = install_queue or Queue()
|
||||
self._event_bus = event_bus
|
||||
self._start_installer_thread()
|
||||
|
||||
def _start_installer_thread(self):
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
|
||||
def _install_next_item(self):
|
||||
done = False
|
||||
while not done:
|
||||
job = self._install_queue.get()
|
||||
if job == STOP_JOB:
|
||||
done = True
|
||||
elif job.status == InstallStatus.WAITING:
|
||||
try:
|
||||
self._signal_job_running(job)
|
||||
self.register_path(job.path)
|
||||
self._signal_job_completed(job)
|
||||
except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
|
||||
self._signal_job_errored(job, e)
|
||||
|
||||
def _signal_job_running(self, job: ModelInstallJob):
|
||||
job.status = InstallStatus.RUNNING
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(job.source)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob):
|
||||
job.status = InstallStatus.COMPLETED
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_completed(job.source, job.dest)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
|
||||
job.set_error(e)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
|
||||
|
||||
def register_path(
|
||||
self,
|
||||
@ -45,7 +101,7 @@ class ModelInstallService(ModelInstallBase, BaseModel):
|
||||
subfolder: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallStatus:
|
||||
) -> ModelInstallJob:
|
||||
raise NotImplementedError
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
@ -68,10 +124,10 @@ class ModelInstallService(ModelInstallBase, BaseModel):
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
def _create_description(self, model_path: Union[Path, str]) -> str:
|
||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
||||
name: str = name or self._create_name(model_path)
|
||||
return f"a {info.model_type} model based on {info.base_type}"
|
||||
def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
||||
info = info or ModelProbe.probe(Path(model_path))
|
||||
name: str = self._create_name(model_path)
|
||||
return f"a {info.model_type} model {name} based on {info.base_type}"
|
||||
|
||||
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
||||
name: str = name or self._create_name(model_path)
|
||||
|
@ -29,6 +29,9 @@ from typing_extensions import Annotated
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
"""Exception for when a duplicate model is detected during installation."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
Loading…
Reference in New Issue
Block a user