feat(nodes): add warning socket event

This commit is contained in:
psychedelicious 2023-06-13 20:51:24 +10:00
parent 42e48b2bef
commit abee37eab3
2 changed files with 35 additions and 1 deletions

View File

@ -49,8 +49,14 @@ class DynamicPromptInvocation(BaseInvocation):
combinatorial: bool = Field( combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator" default=False, description="Whether to use the combinatorial generator"
) )
# wildcard_path: Optional[str] = Field(default=None, description="Wildcard path")
def invoke(self, context: InvocationContext) -> PromptListOutput: def invoke(self, context: InvocationContext) -> PromptListOutput:
# if self.wildcard_path is not None:
# try:
# os.stat(self.wildcard_path)
# except FileNotFoundError:
# context.services.logger.warn(f"Invalid wildcard path ({self.wildcard_path}), ignoring")
try: try:
if self.combinatorial: if self.combinatorial:
generator = CombinatorialPromptGenerator() generator = CombinatorialPromptGenerator()
@ -61,6 +67,16 @@ class DynamicPromptInvocation(BaseInvocation):
except ParseException as e: except ParseException as e:
warning = f"Invalid dynamic prompt: {e}" warning = f"Invalid dynamic prompt: {e}"
context.services.logger.warn(warning) context.services.logger.warn(warning)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
context.services.events.emit_invocation_warning(
warning=warning,
graph_execution_state_id=context.graph_execution_state_id,
node=self.dict(),
source_node_id=source_node_id,
)
return PromptListOutput(prompts=[self.prompt], count=1) return PromptListOutput(prompts=[self.prompt], count=1)
return PromptListOutput(prompts=prompts, count=len(prompts)) return PromptListOutput(prompts=prompts, count=len(prompts))

View File

@ -69,7 +69,7 @@ class EventServiceBase:
source_node_id: str, source_node_id: str,
error: str, error: str,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" """Emitted when an invocation has encountered a fatal error"""
self.__emit_session_event( self.__emit_session_event(
event_name="invocation_error", event_name="invocation_error",
payload=dict( payload=dict(
@ -80,6 +80,24 @@ class EventServiceBase:
), ),
) )
def emit_invocation_warning(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
warning: str,
) -> None:
"""Emitted when an invocation has encountered a state that may be problematic"""
self.__emit_session_event(
event_name="invocation_warning",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
warning=warning,
),
)
def emit_invocation_started( def emit_invocation_started(
self, graph_execution_state_id: str, node: dict, source_node_id: str self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None: ) -> None: