feat(nodes): emit model loading events

- remove dependency on having access to a `node` during emits, would need a bit of additional args passed through the system and I don't think its necessary at this point. this also allowed us to drop an extraneous fetching/parsing of the session from db.
- provide the invocation context to all `get_model()` calls, so the events are able to be emitted
- test all model loading events in the app and confirm socket events are received
This commit is contained in:
psychedelicious 2023-07-16 02:12:01 +10:00
parent 610c3a4512
commit 7b6159f8d6
3 changed files with 16 additions and 31 deletions

View File

@ -76,7 +76,7 @@ def get_scheduler(
scheduler_name, SCHEDULER_MAP['ddim'] scheduler_name, SCHEDULER_MAP['ddim']
) )
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict() **scheduler_info.dict(), context=context,
) )
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation):
model_name=control_info.control_model.model_name, model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet, model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model, base_model=control_info.control_model.base_model,
context=context,
) )
) )
@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}) **lora.dict(exclude={"weight"}), context=context,
) )
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
with ExitStack() as exit_stack,\ with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
with vae_info as vae: with vae_info as vae:
@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation):
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(), context=context,
) )
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@ -105,8 +105,6 @@ class EventServiceBase:
def emit_model_load_started ( def emit_model_load_started (
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -117,8 +115,6 @@ class EventServiceBase:
event_name="model_load_started", event_name="model_load_started",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -129,8 +125,6 @@ class EventServiceBase:
def emit_model_load_completed( def emit_model_load_completed(
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
@ -142,12 +136,12 @@ class EventServiceBase:
event_name="model_load_completed", event_name="model_load_completed",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, hash=model_info.hash,
location=model_info.location,
precision=str(model_info.precision),
), ),
) )

View File

@ -338,7 +338,6 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None, context: Optional[InvocationContext] = None,
) -> ModelInfo: ) -> ModelInfo:
""" """
@ -346,11 +345,9 @@ class ModelManagerService(ModelManagerServiceBase):
part (such as the vae) of a diffusers mode. part (such as the vae) of a diffusers mode.
""" """
# if we are called from within a node, then we get to emit # we can emit model loading events if we are executing with access to the invocation context
# load start and complete events if context:
if node and context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -365,9 +362,8 @@ class ModelManagerService(ModelManagerServiceBase):
submodel, submodel,
) )
if node and context: if context:
self._emit_load_event( self._emit_load_event(
node=node,
context=context, context=context,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
@ -509,23 +505,19 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event( def _emit_load_event(
self, self,
node,
context, context,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: SubModelType, submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None, model_info: Optional[ModelInfo] = None,
): ):
if context.services.queue.is_canceled(context.graph_execution_state_id): if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException() raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info: if model_info:
context.services.events.emit_model_load_completed( context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
@ -535,8 +527,6 @@ class ModelManagerService(ModelManagerServiceBase):
else: else:
context.services.events.emit_model_load_started( context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,