mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, Dict, Union, Optional
|
|
from pydantic import Field
|
|
|
|
from invokeai.app.models.exceptions import CanceledException
|
|
from invokeai.backend.model_manager import (
|
|
ModelConfigStore,
|
|
SubModelType,
|
|
)
|
|
from invokeai.backend.model_manager.cache import CacheStats
|
|
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
|
|
|
from .config import InvokeAIAppConfig
|
|
from .events import EventServiceBase
|
|
from .model_record_service import ModelRecordServiceBase
|
|
|
|
if TYPE_CHECKING:
|
|
from ..invocations.baseinvocation import InvocationContext
|
|
|
|
|
|
class ModelLoadServiceBase(ABC):
    """Abstract interface for services that load models into memory."""

    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None,
    ) -> None:
        """
        Initialize a ModelLoadService.

        :param config: InvokeAIAppConfig object
        :param store: ModelConfigStore or ModelRecordServiceBase object for
            fetching configuration information
        :param event_bus: Optional EventServiceBase object. If provided,
            installation and download events will be sent to the event bus.
        """

    @abstractmethod
    def get_model(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Retrieve the model identified by key.

        :param key: Unique key returned by the ModelConfigStore module.
        :param submodel_type: Submodel to return (required for main models)
        :param context: Optional InvocationContext, used in event reporting.
        :return: ModelInfo object for the requested model.
        """

    @abstractmethod
    def collect_cache_stats(self, cache_stats: CacheStats) -> None:
        """
        Hand a CacheStats object to the service for cache statistics handling.

        NOTE(review): an earlier version of this docstring described resetting
        statistics "for graph with graph_id", but no graph id is part of this
        signature; the exact semantics are defined by the implementation.

        :param cache_stats: CacheStats object to pass to the model cache.
        """
|
|
|
|
|
|
# implementation
|
|
class ModelLoadService(ModelLoadServiceBase):
    """Load models into memory through a ModelLoad object and report load events."""

    # Plain annotations rather than pydantic ``Field`` declarations: this class
    # is not a pydantic model, so ``Field`` would have no effect here. Both
    # attributes are assigned in __init__.
    _loader: ModelLoad  # loader object that performs the actual model loading
    _event_bus: Optional[EventServiceBase]  # event bus passed to __init__, or None

    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None,
    ) -> None:
        """
        Initialize a ModelLoadService.

        :param config: InvokeAIAppConfig object
        :param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
        :param event_bus: Optional EventServiceBase object. If provided, its
            emit_model_event method is registered as an event handler on the loader.
        """
        self._event_bus = event_bus

        # Only pass event_handlers when there is a bus, so that ModelLoad keeps
        # its own default otherwise.
        kwargs: Dict[str, Any] = {}
        if event_bus is not None:
            kwargs["event_handlers"] = [event_bus.emit_model_event]
        self._loader = ModelLoad(config, store, **kwargs)

    def get_model(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Retrieve the indicated model.

        The submodel is required when fetching a main model.

        When an invocation context is supplied, a "model load started" event is
        emitted before the load and a "model load completed" event after it.

        :param key: Unique key identifying the model.
        :param submodel_type: Submodel to return (required for main models)
        :param context: Optional InvocationContext, used in event reporting.
        :return: ModelInfo object returned by the loader.
        :raises CanceledException: if a context is supplied and its graph
            execution has been canceled (checked both before and after loading).
        """
        # we can emit model loading events if we are executing with access to the invocation context
        if context is not None:
            # Announce the load, and honor cancellation, before the load itself runs.
            self._emit_load_event(
                context=context,
                model_key=key,
                submodel=submodel_type,
            )

        model_info: ModelInfo = self._loader.get_model(key, submodel_type)

        if context is not None:
            self._emit_load_event(
                context=context,
                model_key=key,
                submodel=submodel_type,
                model_info=model_info,
            )

        return model_info

    def collect_cache_stats(self, cache_stats: CacheStats) -> None:
        """
        Forward the CacheStats object to the loader's collect_cache_stats().

        NOTE(review): the previous docstring read "Reset model cache statistics.
        Is this used?" -- confirm whether this method still has callers.

        :param cache_stats: CacheStats object to pass to the loader.
        """
        self._loader.collect_cache_stats(cache_stats)

    def _emit_load_event(
        self,
        context: InvocationContext,
        model_key: str,
        submodel: Optional[SubModelType] = None,
        model_info: Optional[ModelInfo] = None,
    ) -> None:
        """
        Emit a model-load event on the context's event service.

        Emits "model load completed" when model_info is given, otherwise
        "model load started".

        :param context: InvocationContext providing the queue and event services.
        :param model_key: Key of the model being loaded.
        :param submodel: Optional submodel type being loaded.
        :param model_info: ModelInfo of the loaded model; None before loading.
        :raises CanceledException: if the context's graph execution was canceled.
        """
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException()

        # Keyword arguments shared by the "started" and "completed" events.
        common: Dict[str, Any] = {
            "queue_id": context.queue_id,
            "queue_item_id": context.queue_item_id,
            "queue_batch_id": context.queue_batch_id,
            "graph_execution_state_id": context.graph_execution_state_id,
            "model_key": model_key,
            "submodel": submodel,
        }

        events = context.services.events
        if model_info is not None:
            events.emit_model_load_completed(**common, model_info=model_info)
        else:
            events.emit_model_load_started(**common)
|