InvokeAI/invokeai/app/services/model_loader_service.py

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Union, Optional
from pydantic import Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import (
    ModelConfigStore,
    SubModelType,
)
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
from .config import InvokeAIAppConfig
from .events import EventServiceBase
from .model_record_service import ModelRecordServiceBase

if TYPE_CHECKING:
    from ..invocations.baseinvocation import InvocationContext


class ModelLoadServiceBase(ABC):
    """Load models into memory."""

    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None,
    ):
        """
        Initialize a ModelLoadService.

        :param config: InvokeAIAppConfig object
        :param store: ModelConfigStore object for fetching configuration information
        :param event_bus: Optional EventServiceBase object. If provided,
            installation and download events will be sent to the event bus.
        """
        pass
    @abstractmethod
    def get_model(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Retrieve the model identified by key.

        :param key: Unique key returned by the ModelConfigStore module.
        :param submodel_type: Submodel to return (required for main models)
        :param context: Optional InvocationContext, used in event reporting.
        """
        pass
    @abstractmethod
    def collect_cache_stats(self, cache_stats: CacheStats):
        """Collect model cache usage statistics into the supplied CacheStats object."""
        pass
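
# A minimal usage sketch of the interface above (hedged: the key is hypothetical,
# and ModelInfo.context is assumed to be a context manager yielding the loaded model):
#
#     loader: ModelLoadServiceBase = ModelLoadService(config, store, event_bus)
#     unet_info = loader.get_model("abc123", SubModelType.UNet)
#     with unet_info.context as unet:
#         ...  # run inference with the loaded submodel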


# implementation
class ModelLoadService(ModelLoadServiceBase):
    """Responsible for loading models from disk into memory."""

    _loader: ModelLoad = Field(description="the ModelLoad object that performs the actual loading")
    _event_bus: Optional[EventServiceBase] = Field(description="an event bus to send model load events to", default=None)

    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None,
    ):
        """
        Initialize a ModelLoadService.

        :param config: InvokeAIAppConfig object
        :param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
        :param event_bus: Optional EventServiceBase object. If provided,
            installation and download events will be sent to the event bus.
        """
        self._event_bus = event_bus
        kwargs: Dict[str, Any] = {}
        if self._event_bus:
            # Forward the loader's model events to the app-wide event bus.
            kwargs.update(event_handlers=[self._event_bus.emit_model_event])
        self._loader = ModelLoad(config, store, **kwargs)
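
    # Construction sketch (hedged: assumes a config singleton accessor and an
    # existing record store; the event bus may be omitted for headless use):
    #
    #     config = InvokeAIAppConfig.get_config()
    #     loader = ModelLoadService(config=config, store=store)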
    def get_model(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Retrieve the indicated model.

        The submodel is required when fetching a main model.
        """
        model_info: ModelInfo = self._loader.get_model(key, submodel_type)

        # We can emit model loading events if we are executing with access to the invocation context.
        if context:
            self._emit_load_event(
                context=context,
                model_key=key,
                submodel=submodel_type,
                model_info=model_info,
            )

        return model_info
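
    # Invocation-side sketch (hedged: the services attribute name and the key are
    # hypothetical; called from within an invocation's invoke() method):
    #
    #     unet_info = context.services.model_loader.get_model(
    #         key="abc123", submodel_type=SubModelType.UNet, context=context
    #     )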
    def collect_cache_stats(self, cache_stats: CacheStats):
        """Collect model cache usage statistics into the supplied CacheStats object."""
        self._loader.collect_cache_stats(cache_stats)
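
    # Usage sketch (hedged: CacheStats is assumed to be a mutable container that
    # the loader fills in; the field names shown are illustrative):
    #
    #     stats = CacheStats()
    #     loader.collect_cache_stats(stats)
    #     print(stats.hits, stats.misses)  # hypothetical fields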
    def _emit_load_event(
        self,
        context: InvocationContext,
        model_key: str,
        submodel: Optional[SubModelType] = None,
        model_info: Optional[ModelInfo] = None,
    ):
        # Abort promptly if the queue item was canceled while the model was loading.
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException()

        if model_info:
            # A populated model_info signals that the load has completed.
            context.services.events.emit_model_load_completed(
                queue_id=context.queue_id,
                queue_item_id=context.queue_item_id,
                queue_batch_id=context.queue_batch_id,
                graph_execution_state_id=context.graph_execution_state_id,
                model_key=model_key,
                submodel=submodel,
                model_info=model_info,
            )
        else:
            # No model_info signals that the load has just started.
            context.services.events.emit_model_load_started(
                queue_id=context.queue_id,
                queue_item_id=context.queue_item_id,
                queue_batch_id=context.queue_batch_id,
                graph_execution_state_id=context.graph_execution_state_id,
                model_key=model_key,
                submodel=submodel,
            )