mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-25 04:51:07 +00:00
import_model and list_install_jobs router APIs written
This commit is contained in:
@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@ -31,6 +32,13 @@ class EventServiceBase:
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@ -321,7 +329,7 @@ class EventServiceBase:
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_started",
|
||||
payload={
|
||||
"source": source
|
||||
@ -335,7 +343,7 @@ class EventServiceBase:
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={
|
||||
"source": source,
|
||||
@ -354,7 +362,7 @@ class EventServiceBase:
|
||||
:param source: Source of the model
|
||||
:param exception: The exception that raised the error
|
||||
"""
|
||||
self.__emit_queue_event(
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={
|
||||
"source": source,
|
||||
|
@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@ -51,6 +52,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
model_install: "ModelRecordInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@ -79,6 +81,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@ -105,6 +108,7 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
@ -1,6 +1,12 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException
|
||||
from .model_install_base import InstallStatus, ModelInstallServiceBase, ModelInstallJob, UnknownInstallJobException, ModelSource
|
||||
from .model_install_default import ModelInstallService
|
||||
|
||||
__all__ = ['ModelInstallServiceBase', 'ModelInstallService', 'InstallStatus', 'ModelInstallJob', 'UnknownInstallJobException']
|
||||
__all__ = ['ModelInstallServiceBase',
|
||||
'ModelInstallService',
|
||||
'InstallStatus',
|
||||
'ModelInstallJob',
|
||||
'UnknownInstallJobException',
|
||||
'ModelSource',
|
||||
]
|
||||
|
@ -183,8 +183,12 @@ class ModelInstallServiceBase(ABC):
|
||||
"""Return the ModelInstallJob corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||
"""Return a dict in which keys are model sources and values are corresponding model install jobs."""
|
||||
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
|
||||
"""
|
||||
List active and complete install jobs.
|
||||
|
||||
:param source: Filter by jobs whose sources are a partial match to the argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self) -> None:
|
||||
|
@ -12,10 +12,9 @@ from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, DuplicateModelException
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
DuplicateModelException,
|
||||
InvalidModelConfigException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import ModelType, BaseModelType
|
||||
@ -26,6 +25,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
|
||||
from .model_install_base import ModelSource, InstallStatus, ModelInstallJob, ModelInstallServiceBase, UnknownInstallJobException
|
||||
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(source="stop", local_path=Path("/dev/null"))
|
||||
|
||||
@ -76,9 +76,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def get_jobs(self) -> Dict[ModelSource, ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
|
||||
@ -184,6 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
else: # waiting for download queue implementation
|
||||
raise NotImplementedError
|
||||
|
||||
def list_jobs(self, source: Optional[ModelSource]=None) -> List[ModelInstallJob]: # noqa D102
|
||||
jobs = self._install_jobs
|
||||
if not source:
|
||||
return jobs.values()
|
||||
else:
|
||||
return [jobs[x] for x in jobs if source in str(x)]
|
||||
|
||||
def get_job(self, source: ModelSource) -> ModelInstallJob: # noqa D102
|
||||
try:
|
||||
return self._install_jobs[source]
|
||||
|
Reference in New Issue
Block a user