feat(events): add submodel_type to model load events

This was lost during the MM2 migration.
psychedelicious 2024-03-14 17:38:49 +11:00 committed by blessedcoolant
parent ba3d8af161
commit ef55077e84
2 changed files with 10 additions and 0 deletions
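
The payload entries added below pass the SubModelType value straight into a JSON-serialized event payload. A minimal, self-contained sketch of why that round-trips cleanly, using a stand-in enum rather than InvokeAI's actual SubModelType (which has more members):

import json
from enum import Enum
from typing import Optional

class SubModelType(str, Enum):
    # Stand-in for invokeai.backend.model_manager.config.SubModelType;
    # the real enum has more members, but the str-subclass behavior is the point.
    UNet = "unet"
    VAE = "vae"

def payload_fragment(submodel_type: Optional[SubModelType]) -> str:
    # json.dumps renders a str-subclass enum as its plain string value,
    # and None as null, so the field needs no extra conversion.
    return json.dumps({"submodel_type": submodel_type})

print(payload_fragment(SubModelType.UNet))  # {"submodel_type": "unet"}
print(payload_fragment(None))               # {"submodel_type": null}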


@@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
 )
 from invokeai.app.util.misc import get_timestamp
 from invokeai.backend.model_manager import AnyModelConfig
+from invokeai.backend.model_manager.config import SubModelType


 class EventServiceBase:
@@ -180,6 +181,7 @@ class EventServiceBase:
         queue_batch_id: str,
         graph_execution_state_id: str,
         model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Emitted when a model is requested"""
         self.__emit_queue_event(
@@ -190,6 +192,7 @@ class EventServiceBase:
                 "queue_batch_id": queue_batch_id,
                 "graph_execution_state_id": graph_execution_state_id,
                 "model_config": model_config.model_dump(mode="json"),
+                "submodel_type": submodel_type,
             },
         )
@@ -200,6 +203,7 @@ class EventServiceBase:
         queue_batch_id: str,
         graph_execution_state_id: str,
         model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Emitted when a model is correctly loaded (returns model info)"""
         self.__emit_queue_event(
@@ -210,6 +214,7 @@ class EventServiceBase:
                 "queue_batch_id": queue_batch_id,
                 "graph_execution_state_id": graph_execution_state_id,
                 "model_config": model_config.model_dump(mode="json"),
+                "submodel_type": submodel_type,
             },
         )
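
On the consumer side, listeners for these events can now tell submodel loads apart from full-model loads. A hedged sketch of such a handler; the function name and any payload keys other than "submodel_type" are assumptions, not InvokeAI's actual API:

def on_model_load_started(payload: dict) -> None:
    # "submodel_type" arrives as the enum's string value ("unet", "vae", ...)
    # or None when the model is loaded whole.
    submodel = payload.get("submodel_type")
    if submodel is None:
        print("loading full model")
    else:
        print(f"loading submodel: {submodel}")

on_model_load_started({"submodel_type": "unet"})  # loading submodel: unet
on_model_load_started({"submodel_type": None})    # loading full model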


@@ -68,6 +68,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._emit_load_event(
             context_data=context_data,
             model_config=model_config,
+            submodel_type=submodel_type,
         )

         implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type)  # type: ignore
@@ -82,6 +83,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._emit_load_event(
             context_data=context_data,
             model_config=model_config,
+            submodel_type=submodel_type,
             loaded=True,
         )
         return loaded_model
@@ -91,6 +93,7 @@ class ModelLoadService(ModelLoadServiceBase):
         context_data: InvocationContextData,
         model_config: AnyModelConfig,
         loaded: Optional[bool] = False,
+        submodel_type: Optional[SubModelType] = None,
     ) -> None:
         if not self._invoker:
             return
@@ -102,6 +105,7 @@ class ModelLoadService(ModelLoadServiceBase):
                 queue_batch_id=context_data.queue_item.batch_id,
                 graph_execution_state_id=context_data.queue_item.session_id,
                 model_config=model_config,
+                submodel_type=submodel_type,
             )
         else:
             self._invoker.services.events.emit_model_load_completed(
@@ -110,4 +114,5 @@ class ModelLoadService(ModelLoadServiceBase):
                 queue_batch_id=context_data.queue_item.batch_id,
                 graph_execution_state_id=context_data.queue_item.session_id,
                 model_config=model_config,
+                submodel_type=submodel_type,
             )
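
For reference, the started/completed dispatch that _emit_load_event performs can be sketched in isolation. The classes below are stand-ins for the real event service, not InvokeAI's implementation:

from typing import Optional

class StubEvents:
    # Stand-in for self._invoker.services.events in the diff above.
    def emit_model_load_started(self, **kwargs) -> None:
        print("model_load_started:", kwargs)

    def emit_model_load_completed(self, **kwargs) -> None:
        print("model_load_completed:", kwargs)

def emit_load_event(events: StubEvents, model_config: dict,
                    loaded: bool = False,
                    submodel_type: Optional[str] = None) -> None:
    # Mirrors _emit_load_event: the same kwargs, including submodel_type,
    # go to the "started" event before the load and the "completed" event after.
    if not loaded:
        events.emit_model_load_started(model_config=model_config, submodel_type=submodel_type)
    else:
        events.emit_model_load_completed(model_config=model_config, submodel_type=submodel_type)

events = StubEvents()
emit_load_event(events, {"name": "demo"}, submodel_type="unet")               # before the load
emit_load_event(events, {"name": "demo"}, loaded=True, submodel_type="unet")  # after the load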