import_model and list_install_jobs router APIs written

This commit is contained in:
Lincoln Stein
2023-11-25 21:45:59 -05:00
parent ec510d34b5
commit 8aefe2cefe
10 changed files with 145 additions and 20 deletions

View File

@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -31,6 +32,13 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -321,7 +329,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_started",
payload={
"source": source
@ -335,7 +343,7 @@ class EventServiceBase:
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
@ -354,7 +362,7 @@ class EventServiceBase:
:param source: Source of the model
:param exception: The exception that raised the error
"""
self.__emit_queue_event(
self.__emit_model_event(
event_name="model_install_error",
payload={
"source": source,

View File

@ -23,6 +23,7 @@ if TYPE_CHECKING:
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .model_install import ModelInstallServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -51,6 +52,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
model_install: "ModelRecordInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -79,6 +81,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -105,6 +108,7 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -1,6 +1,12 @@
"""Initialization file for model install service package."""
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
from .model_install_default import ModelInstallService
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
__all__ = ['ModelInstallServiceBase',
'ModelInstallService',
'InstallStatus',
'ModelInstallJob',
'UnknownInstallJobException',
'ModelSource',
]

View File

@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
"""Return the ModelInstallJob corresponding to the provided source."""
@abstractmethod
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
:param source: Filter by jobs whose sources are a partial match to the argument.
"""
@abstractmethod
def prune_jobs(self) -> None:

View File

@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
from invokeai.backend.model_manager.config import (
AnyModelConfig,
DuplicateModelException,
InvalidModelConfigException,
)
from invokeai.backend.model_manager.config import ModelType, BaseModelType
@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
return self._install_jobs
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
else: # waiting for download queue implementation
raise NotImplementedError
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
jobs = self._install_jobs
if not source:
return jobs.values()
else:
return [jobs[x] for x in jobs if source in str(x)]
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
try:
return self._install_jobs[source]