mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
start implementation of installer
This commit is contained in:
parent
6c56233edc
commit
9ea3126118
@ -344,7 +344,8 @@ class EventServiceBase:
|
|||||||
|
|
||||||
def emit_model_install_error(self,
|
def emit_model_install_error(self,
|
||||||
source:str,
|
source:str,
|
||||||
exception: Exception,
|
error_type: str,
|
||||||
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Emitted when an install job encounters an exception.
|
Emitted when an install job encounters an exception.
|
||||||
@ -352,13 +353,6 @@ class EventServiceBase:
|
|||||||
:param source: Source of the model
|
:param source: Source of the model
|
||||||
:param exception: The exception that raised the error
|
:param exception: The exception that raised the error
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Revisit:
|
|
||||||
# it makes more sense to receive an exception and break it out
|
|
||||||
# into error_type and error here, rather than at the caller's side
|
|
||||||
error_type = exception.__class__.__name__,
|
|
||||||
error = traceback.format_exc(),
|
|
||||||
|
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
event_name="model_install_error",
|
event_name="model_install_error",
|
||||||
payload={
|
payload={
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
|
import traceback
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union, List
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
|
||||||
@ -14,13 +16,25 @@ class InstallStatus(str, Enum):
|
|||||||
"""State of an install job running in the background."""
|
"""State of an install job running in the background."""
|
||||||
|
|
||||||
WAITING = "waiting" # waiting to be dequeued
|
WAITING = "waiting" # waiting to be dequeued
|
||||||
|
RUNNING = "running" # being processed
|
||||||
COMPLETED = "completed" # finished running
|
COMPLETED = "completed" # finished running
|
||||||
ERROR = "error" # terminated with an error message
|
ERROR = "error" # terminated with an error message
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallStatus(BaseModel):
|
class ModelInstallJob(BaseModel):
|
||||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") # noqa #501
|
"""Object that tracks the current status of an install request."""
|
||||||
error: Optional[Exception] = Field(default=None, description="Exception that led to status==ERROR") # noqa #501
|
|
||||||
|
source: Union[str, Path, AnyHttpUrl] = Field(description="Source (URL, repo_id, or local path) of model")
|
||||||
|
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||||
|
local_path: Optional[Path] = Field(default=None, description="Path to locally-downloaded model")
|
||||||
|
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
|
||||||
|
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
|
||||||
|
|
||||||
|
def set_error(self, e: Exception) -> None:
|
||||||
|
"""Record the error and traceback from an exception."""
|
||||||
|
self.error_type = e.__class__.__name__
|
||||||
|
self.error = traceback.format_exc()
|
||||||
|
self.status = InstallStatus.ERROR
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallServiceBase(ABC):
|
class ModelInstallServiceBase(ABC):
|
||||||
@ -42,18 +56,6 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def config(self) -> InvokeAIAppConfig:
|
|
||||||
"""Return the app_config used by the installer."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def store(self) -> ModelRecordServiceBase:
|
|
||||||
"""Return the storage backend used by the installer."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
@ -108,7 +110,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallStatus:
|
) -> ModelInstallJob:
|
||||||
"""Install the indicated model.
|
"""Install the indicated model.
|
||||||
|
|
||||||
:param source: Either a URL or a HuggingFace repo_id.
|
:param source: Either a URL or a HuggingFace repo_id.
|
||||||
@ -142,7 +144,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
The `inplace` flag does not affect the behavior of downloaded
|
The `inplace` flag does not affect the behavior of downloaded
|
||||||
models, which are always moved into the `models` directory.
|
models, which are always moved into the `models` directory.
|
||||||
|
|
||||||
The call returns a ModelInstallStatus object which can be
|
The call returns a ModelInstallJob object which can be
|
||||||
polled to learn the current status and/or error message.
|
polled to learn the current status and/or error message.
|
||||||
|
|
||||||
Variants recognized by HuggingFace currently are:
|
Variants recognized by HuggingFace currently are:
|
||||||
@ -195,34 +197,3 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# The following are internal methods
|
|
||||||
@abstractmethod
|
|
||||||
def _create_name(self, model_path: Union[Path, str]) -> str:
|
|
||||||
"""
|
|
||||||
Creates a default name for the model.
|
|
||||||
|
|
||||||
:param model_path: Path to the model on disk.
|
|
||||||
:return str: Model name
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_description(self, model_path: Union[Path, str]) -> str:
|
|
||||||
"""
|
|
||||||
Creates a default description for the model.
|
|
||||||
|
|
||||||
:param model_path: Path to the model on disk.
|
|
||||||
:return str: Model description
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
|
||||||
"""
|
|
||||||
Creates a unique ID for the model for use with the model records module. # noqa E501
|
|
||||||
|
|
||||||
:param model_path: Path to the model on disk.
|
|
||||||
:param name: (optional) non-default name for the model
|
|
||||||
:return str: Unique ID
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
@ -1,21 +1,77 @@
|
|||||||
|
"""Model installation class."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union, List
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
from .model_install_base import InstallStatus, ModelInstallStatus, ModelInstallServiceBase
|
from .model_install_base import InstallStatus, ModelInstallJob, ModelInstallServiceBase
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.model_probe import ModelProbeInfo, ModelProbe
|
||||||
|
from invokeai.backend.model_manager.config import InvalidModelConfigException, DuplicateModelException
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
class ModelInstallService(ModelInstallBase, BaseModel):
|
# marker that the queue is done and that thread should exit
|
||||||
|
STOP_JOB = ModelInstallJob(source="stop")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInstallService(ModelInstallServiceBase):
|
||||||
"""class for InvokeAI model installation."""
|
"""class for InvokeAI model installation."""
|
||||||
|
|
||||||
config: InvokeAIAppConfig
|
config: InvokeAIAppConfig
|
||||||
store: ModelRecordServiceBase
|
store: ModelRecordServiceBase
|
||||||
event_bus: Optional[EventServiceBase] = Field(default=None)
|
_event_bus: Optional[EventServiceBase] = None
|
||||||
install_queue: Queue = Field(default_factory=Queue)
|
_install_queue: Queue
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
|
store: ModelRecordServiceBase,
|
||||||
|
install_queue: Optional[Queue] = None,
|
||||||
|
event_bus: Optional[EventServiceBase] = None
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.store = store
|
||||||
|
self._install_queue = install_queue or Queue()
|
||||||
|
self._event_bus = event_bus
|
||||||
|
self._start_installer_thread()
|
||||||
|
|
||||||
|
def _start_installer_thread(self):
|
||||||
|
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||||
|
|
||||||
|
def _install_next_item(self):
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
job = self._install_queue.get()
|
||||||
|
if job == STOP_JOB:
|
||||||
|
done = True
|
||||||
|
elif job.status == InstallStatus.WAITING:
|
||||||
|
try:
|
||||||
|
self._signal_job_running(job)
|
||||||
|
self.register_path(job.path)
|
||||||
|
self._signal_job_completed(job)
|
||||||
|
except (OSError, DuplicateModelException, InvalidModelConfigException) as e:
|
||||||
|
self._signal_job_errored(job, e)
|
||||||
|
|
||||||
|
def _signal_job_running(self, job: ModelInstallJob):
|
||||||
|
job.status = InstallStatus.RUNNING
|
||||||
|
if self._event_bus:
|
||||||
|
self._event_bus.emit_model_install_started(job.source)
|
||||||
|
|
||||||
|
def _signal_job_completed(self, job: ModelInstallJob):
|
||||||
|
job.status = InstallStatus.COMPLETED
|
||||||
|
if self._event_bus:
|
||||||
|
self._event_bus.emit_model_install_completed(job.source, job.dest)
|
||||||
|
|
||||||
|
def _signal_job_errored(self, job: ModelInstallJob, e: Exception):
|
||||||
|
job.set_error(e)
|
||||||
|
if self._event_bus:
|
||||||
|
self._event_bus.emit_model_install_error(job.source, job.error_type, job.error)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
@ -45,7 +101,7 @@ class ModelInstallService(ModelInstallBase, BaseModel):
|
|||||||
subfolder: Optional[str] = None,
|
subfolder: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallStatus:
|
) -> ModelInstallJob:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||||
@ -68,10 +124,10 @@ class ModelInstallService(ModelInstallBase, BaseModel):
|
|||||||
else:
|
else:
|
||||||
return model_path.name
|
return model_path.name
|
||||||
|
|
||||||
def _create_description(self, model_path: Union[Path, str]) -> str:
|
def _create_description(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
||||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
info = info or ModelProbe.probe(Path(model_path))
|
||||||
name: str = name or self._create_name(model_path)
|
name: str = self._create_name(model_path)
|
||||||
return f"a {info.model_type} model based on {info.base_type}"
|
return f"a {info.model_type} model {name} based on {info.base_type}"
|
||||||
|
|
||||||
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
def _create_id(self, model_path: Union[Path, str], name: Optional[str] = None) -> str:
|
||||||
name: str = name or self._create_name(model_path)
|
name: str = name or self._create_name(model_path)
|
||||||
|
@ -29,6 +29,9 @@ from typing_extensions import Annotated
|
|||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||||
|
|
||||||
|
class DuplicateModelException(Exception):
|
||||||
|
"""Exception for when a duplicate model is detected during installation."""
|
||||||
|
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
"""Base model type."""
|
"""Base model type."""
|
||||||
|
Loading…
Reference in New Issue
Block a user