mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main and resolve conflicts
This commit is contained in:
@ -7,7 +7,6 @@ from typing import Callable, Dict, Optional
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@ -18,18 +17,12 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -11,7 +11,6 @@ from torch import load as torch_load
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
@ -59,25 +58,18 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@ -87,13 +79,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
@ -150,32 +138,3 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
Reference in New Issue
Block a user