From 7b6159f8d65a3876ef7ecc8e8a31dacf9a50d213 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 16 Jul 2023 02:12:01 +1000 Subject: [PATCH] feat(nodes): emit model loading events - remove dependency on having access to a `node` during emits, would need a bit of additional args passed through the system and I don't think its necessary at this point. this also allowed us to drop an extraneous fetching/parsing of the session from db. - provide the invocation context to all `get_model()` calls, so the events are able to be emitted - test all model loading events in the app and confirm socket events are received --- invokeai/app/invocations/latent.py | 15 +++++++------- invokeai/app/services/events.py | 12 +++-------- .../app/services/model_manager_service.py | 20 +++++-------------- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index baf78c7c23..f9844c5932 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -76,7 +76,7 @@ def get_scheduler( scheduler_name, SCHEDULER_MAP['ddim'] ) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict() + **scheduler_info.dict(), context=context, ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation): model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, + context=context, ) ) @@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}) + **lora.dict(exclude={"weight"}), context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + **self.unet.unet.dict(), context=context, ) with ExitStack() as exit_stack,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ @@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}) + **lora.dict(exclude={"weight"}), context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + **self.unet.unet.dict(), context=context, ) with ExitStack() as exit_stack,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ @@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation): latents = context.services.latents.get(self.latents.latents_name) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), + **self.vae.vae.dict(), context=context, ) with vae_info as vae: @@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation): #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), + **self.vae.vae.dict(), context=context, ) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 6c516c9b74..30d1b5e7a9 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -105,8 +105,6 @@ class EventServiceBase: def emit_model_load_started ( self, graph_execution_state_id: str, - node: dict, - source_node_id: str, model_name: str, base_model: BaseModelType, model_type: ModelType, @@ -117,8 +115,6 @@ class EventServiceBase: event_name="model_load_started", payload=dict( graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -129,8 +125,6 @@ class EventServiceBase: def emit_model_load_completed( self, graph_execution_state_id: str, - node: dict, - source_node_id: str, model_name: str, base_model: BaseModelType, model_type: ModelType, @@ -142,12 +136,12 @@ class EventServiceBase: event_name="model_load_completed", payload=dict( graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + hash=model_info.hash, + location=model_info.location, + precision=str(model_info.precision), ), ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 67db5c9478..51ba1199fb 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -338,7 +338,6 @@ class ModelManagerService(ModelManagerServiceBase): base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: """ @@ -346,11 +345,9 @@ class ModelManagerService(ModelManagerServiceBase): part (such as the vae) of a diffusers mode. """ - # if we are called from within a node, then we get to emit - # load start and complete events - if node and context: + # we can emit model loading events if we are executing with access to the invocation context + if context: self._emit_load_event( - node=node, context=context, model_name=model_name, base_model=base_model, @@ -365,9 +362,8 @@ class ModelManagerService(ModelManagerServiceBase): submodel, ) - if node and context: + if context: self._emit_load_event( - node=node, context=context, model_name=model_name, base_model=base_model, @@ -509,23 +505,19 @@ class ModelManagerService(ModelManagerServiceBase): def _emit_load_event( self, - node, context, model_name: str, base_model: BaseModelType, model_type: ModelType, - submodel: SubModelType, + submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): if context.services.queue.is_canceled(context.graph_execution_state_id): raise CanceledException() - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[node.id] + if model_info: context.services.events.emit_model_load_completed( graph_execution_state_id=context.graph_execution_state_id, - node=node.dict(), - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -535,8 +527,6 @@ class ModelManagerService(ModelManagerServiceBase): else: context.services.events.emit_model_load_started( graph_execution_state_id=context.graph_execution_state_id, - node=node.dict(), - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type,