From abee37eab3a804ca159fb7b913270238f5b90d3b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Jun 2023 20:51:24 +1000 Subject: [PATCH] feat(nodes): add warning socket event --- invokeai/app/invocations/prompt.py | 16 ++++++++++++++++ invokeai/app/services/events.py | 20 +++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index fd9e08912d..e29640d211 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -49,8 +49,14 @@ class DynamicPromptInvocation(BaseInvocation): combinatorial: bool = Field( default=False, description="Whether to use the combinatorial generator" ) + # wildcard_path: Optional[str] = Field(default=None, description="Wildcard path") def invoke(self, context: InvocationContext) -> PromptListOutput: + # if self.wildcard_path is not None: + # try: + # os.stat(self.wildcard_path) + # except FileNotFoundError: + # context.services.logger.warn(f"Invalid wildcard path ({self.wildcard_path}), ignoring") try: if self.combinatorial: generator = CombinatorialPromptGenerator() @@ -61,6 +67,16 @@ class DynamicPromptInvocation(BaseInvocation): except ParseException as e: warning = f"Invalid dynamic prompt: {e}" context.services.logger.warn(warning) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id + ) + source_node_id = graph_execution_state.prepared_source_mapping[self.id] + context.services.events.emit_invocation_warning( + warning=warning, + graph_execution_state_id=context.graph_execution_state_id, + node=self.dict(), + source_node_id=source_node_id, + ) return PromptListOutput(prompts=[self.prompt], count=1) return PromptListOutput(prompts=prompts, count=len(prompts)) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 788f24dbce..a22b2989d3 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -69,7 +69,7 @@ class EventServiceBase: source_node_id: str, error: str, ) -> None: - """Emitted when an invocation has completed""" + """Emitted when an invocation has encountered a fatal error""" self.__emit_session_event( event_name="invocation_error", payload=dict( @@ -80,6 +80,24 @@ class EventServiceBase: ), ) + def emit_invocation_warning( + self, + graph_execution_state_id: str, + node: dict, + source_node_id: str, + warning: str, + ) -> None: + """Emitted when an invocation has encountered a state that may be problematic""" + self.__emit_session_event( + event_name="invocation_warning", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + node=node, + source_node_id=source_node_id, + warning=warning, + ), + ) + def emit_invocation_started( self, graph_execution_state_id: str, node: dict, source_node_id: str ) -> None: