Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix: canvas not working on queue
Add `batch_id` to outbound events. This necessitates adding it to both `InvocationContext` and `InvocationQueueItem`. This allows the canvas to receive images. When the user enqueues a batch on the canvas, it is expected that all images from that batch are directed to the canvas. The simplest, most flexible solution is to add the `batch_id` to the invocation context-y stuff. Then everything knows what batch it came from, and we can have the canvas pick up images associated with its list of canvas `batch_id`s.
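Conceptually, the canvas-side matching this enables reduces to the sketch below (TypeScript). The event shape follows the `InvocationCompleteEvent` type updated in this diff; `rememberCanvasBatch` and `shouldStageOnCanvas` are illustrative helpers for this note, not code from the commit — in the app the same roles are played by the `canvasBatchIdAdded` reducer and the `canvas.batchIds.includes(queue_batch_id)` check in the invocation-complete listener.

// Sketch only: the event shape mirrors InvocationCompleteEvent from this diff;
// the helpers are illustrative, not part of the commit.
type InvocationCompleteEvent = {
  queue_id: string;
  queue_item_id: number;
  queue_batch_id: string;
  graph_execution_state_id: string;
};

// The canvas keeps the ids of every batch it has enqueued.
const canvasBatchIds: string[] = [];

// Called when the user enqueues a batch from the canvas
// (the canvasBatchIdAdded reducer plays this role in the app).
function rememberCanvasBatch(batchId: string): void {
  canvasBatchIds.push(batchId);
}

// An incoming image belongs in the canvas staging area iff its event
// carries a batch id that the canvas enqueued.
function shouldStageOnCanvas(event: InvocationCompleteEvent): boolean {
  return canvasBatchIds.includes(event.queue_batch_id);
}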
Committed by: Kent Keirsey
Parent: 1c38cce16d
Commit: bdfdf854fc
@@ -425,12 +425,21 @@ class InvocationContext:
     graph_execution_state_id: str
     queue_id: str
     queue_item_id: int
+    queue_batch_id: str

-    def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: int, graph_execution_state_id: str):
+    def __init__(
+        self,
+        services: InvocationServices,
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
+    ):
         self.services = services
         self.graph_execution_state_id = graph_execution_state_id
         self.queue_id = queue_id
         self.queue_item_id = queue_item_id
+        self.queue_batch_id = queue_batch_id


 class BaseInvocationOutput(BaseModel):
@@ -30,6 +30,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         node: dict,
         source_node_id: str,
@@ -44,6 +45,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 node_id=node.get("id"),
                 source_node_id=source_node_id,
@@ -58,6 +60,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         result: dict,
         node: dict,
@@ -69,6 +72,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
@@ -80,6 +84,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         node: dict,
         source_node_id: str,
@@ -92,6 +97,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
@@ -101,7 +107,13 @@ class EventServiceBase:
             )

     def emit_invocation_started(
-        self, queue_id: str, queue_item_id: int, graph_execution_state_id: str, node: dict, source_node_id: str
+        self,
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
+        node: dict,
+        source_node_id: str,
     ) -> None:
         """Emitted when an invocation has started"""
         self.__emit_queue_event(
@@ -109,19 +121,23 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
             ),
         )

-    def emit_graph_execution_complete(self, queue_id: str, queue_item_id: int, graph_execution_state_id: str) -> None:
+    def emit_graph_execution_complete(
+        self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
+    ) -> None:
         """Emitted when a session has completed all invocations"""
         self.__emit_queue_event(
             event_name="graph_execution_state_complete",
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
             ),
         )
@@ -130,6 +146,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         model_name: str,
         base_model: BaseModelType,
@@ -142,6 +159,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
@@ -154,6 +172,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         model_name: str,
         base_model: BaseModelType,
@@ -167,6 +186,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
@@ -182,6 +202,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         error_type: str,
         error: str,
@@ -192,6 +213,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 error_type=error_type,
                 error=error,
@@ -202,6 +224,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
         node_id: str,
         error_type: str,
@@ -213,6 +236,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
                 node_id=node_id,
                 error_type=error_type,
@@ -224,6 +248,7 @@ class EventServiceBase:
         self,
         queue_id: str,
         queue_item_id: int,
+        queue_batch_id: str,
         graph_execution_state_id: str,
     ) -> None:
         """Emitted when a session is canceled"""
@@ -232,6 +257,7 @@ class EventServiceBase:
             payload=dict(
                 queue_id=queue_id,
                 queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
                 graph_execution_state_id=graph_execution_state_id,
             ),
         )
@@ -15,6 +15,9 @@ class InvocationQueueItem(BaseModel):
     session_queue_item_id: int = Field(
         description="The ID of session queue item from which this invocation queue item came"
     )
+    session_queue_batch_id: str = Field(
+        description="The ID of the session batch from which this invocation queue item came"
+    )
     invoke_all: bool = Field(default=False)
     timestamp: float = Field(default_factory=time.time)

@@ -18,7 +18,12 @@ class Invoker:
         self._start()

     def invoke(
-        self, queue_id: str, queue_item_id: int, graph_execution_state: GraphExecutionState, invoke_all: bool = False
+        self,
+        session_queue_id: str,
+        session_queue_item_id: int,
+        session_queue_batch_id: str,
+        graph_execution_state: GraphExecutionState,
+        invoke_all: bool = False,
     ) -> Optional[str]:
         """Determines the next node to invoke and enqueues it, preparing if needed.
         Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@@ -34,8 +39,9 @@ class Invoker:
         # Queue the invocation
         self.services.queue.put(
             InvocationQueueItem(
-                session_queue_item_id=queue_item_id,
-                session_queue_id=queue_id,
+                session_queue_id=session_queue_id,
+                session_queue_item_id=session_queue_item_id,
+                session_queue_batch_id=session_queue_batch_id,
                 graph_execution_state_id=graph_execution_state.id,
                 invocation_id=invocation.id,
                 invoke_all=invoke_all,
@@ -539,6 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
             context.services.events.emit_model_load_completed(
                 queue_id=context.queue_id,
                 queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
                 graph_execution_state_id=context.graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
@@ -550,6 +551,7 @@ class ModelManagerService(ModelManagerServiceBase):
             context.services.events.emit_model_load_started(
                 queue_id=context.queue_id,
                 queue_item_id=context.queue_item_id,
+                queue_batch_id=context.queue_batch_id,
                 graph_execution_state_id=context.graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
@@ -57,6 +57,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 except Exception as e:
                     self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
                     self.__invoker.services.events.emit_session_retrieval_error(
+                        queue_batch_id=queue_item.session_queue_batch_id,
                         queue_item_id=queue_item.session_queue_item_id,
                         queue_id=queue_item.session_queue_id,
                         graph_execution_state_id=queue_item.graph_execution_state_id,
@@ -70,6 +71,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 except Exception as e:
                     self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
                     self.__invoker.services.events.emit_invocation_retrieval_error(
+                        queue_batch_id=queue_item.session_queue_batch_id,
                         queue_item_id=queue_item.session_queue_item_id,
                         queue_id=queue_item.session_queue_id,
                         graph_execution_state_id=queue_item.graph_execution_state_id,
@@ -84,6 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

                 # Send starting event
                 self.__invoker.services.events.emit_invocation_started(
+                    queue_batch_id=queue_item.session_queue_batch_id,
                     queue_item_id=queue_item.session_queue_item_id,
                     queue_id=queue_item.session_queue_id,
                     graph_execution_state_id=graph_execution_state.id,
@@ -106,6 +109,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                             graph_execution_state_id=graph_execution_state.id,
                             queue_item_id=queue_item.session_queue_item_id,
                             queue_id=queue_item.session_queue_id,
+                            queue_batch_id=queue_item.session_queue_batch_id,
                         )
                     )

@@ -121,6 +125,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

                 # Send complete event
                 self.__invoker.services.events.emit_invocation_complete(
+                    queue_batch_id=queue_item.session_queue_batch_id,
                     queue_item_id=queue_item.session_queue_item_id,
                     queue_id=queue_item.session_queue_id,
                     graph_execution_state_id=graph_execution_state.id,
@@ -150,6 +155,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                 # Send error event
                 self.__invoker.services.events.emit_invocation_error(
+                    queue_batch_id=queue_item.session_queue_batch_id,
                     queue_item_id=queue_item.session_queue_item_id,
                     queue_id=queue_item.session_queue_id,
                     graph_execution_state_id=graph_execution_state.id,
@@ -170,14 +176,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 if queue_item.invoke_all and not is_complete:
                     try:
                         self.__invoker.invoke(
-                            queue_item_id=queue_item.session_queue_item_id,
-                            queue_id=queue_item.session_queue_id,
+                            session_queue_batch_id=queue_item.session_queue_batch_id,
+                            session_queue_item_id=queue_item.session_queue_item_id,
+                            session_queue_id=queue_item.session_queue_id,
                             graph_execution_state=graph_execution_state,
                             invoke_all=True,
                         )
                     except Exception as e:
                         self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                         self.__invoker.services.events.emit_invocation_error(
+                            queue_batch_id=queue_item.session_queue_batch_id,
                             queue_item_id=queue_item.session_queue_item_id,
                             queue_id=queue_item.session_queue_id,
                             graph_execution_state_id=graph_execution_state.id,
@@ -188,6 +196,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     )
                 elif is_complete:
                     self.__invoker.services.events.emit_graph_execution_complete(
+                        queue_batch_id=queue_item.session_queue_batch_id,
                         queue_item_id=queue_item.session_queue_item_id,
                         queue_id=queue_item.session_queue_id,
                         graph_execution_state_id=graph_execution_state.id,
@@ -102,8 +102,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
                 self.__queue_item = queue_item
                 self.__invoker.services.graph_execution_manager.set(queue_item.session)
                 self.__invoker.invoke(
-                    queue_item_id=queue_item.item_id,
-                    queue_id=queue_item.queue_id,
+                    session_queue_batch_id=queue_item.batch_id,
+                    session_queue_id=queue_item.queue_id,
+                    session_queue_item_id=queue_item.item_id,
                     graph_execution_state=queue_item.session,
                     invoke_all=True,
                 )
@@ -562,6 +562,7 @@ class SqliteSessionQueue(SessionQueueBase):
            self.__invoker.services.events.emit_session_canceled(
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
+                queue_batch_id=queue_item.batch_id,
                graph_execution_state_id=queue_item.session_id,
            )
            self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
@@ -604,6 +605,7 @@ class SqliteSessionQueue(SessionQueueBase):
                self.__invoker.services.events.emit_session_canceled(
                    queue_item_id=current_queue_item.item_id,
                    queue_id=current_queue_item.queue_id,
+                    queue_batch_id=current_queue_item.batch_id,
                    graph_execution_state_id=current_queue_item.session_id,
                )
                self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
@@ -649,6 +651,7 @@ class SqliteSessionQueue(SessionQueueBase):
                self.__invoker.services.events.emit_session_canceled(
                    queue_item_id=current_queue_item.item_id,
                    queue_id=current_queue_item.queue_id,
+                    queue_batch_id=current_queue_item.batch_id,
                    graph_execution_state_id=current_queue_item.session_id,
                )
                self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
@@ -112,6 +112,7 @@ def stable_diffusion_step_callback(
    context.services.events.emit_generator_progress(
        queue_id=context.queue_id,
        queue_item_id=context.queue_item_id,
+        queue_batch_id=context.queue_batch_id,
        graph_execution_state_id=context.graph_execution_state_id,
        node=node,
        source_node_id=source_node_id,
@@ -1,7 +1,7 @@
 import { isAnyOf } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import {
-  canvasBatchesAndSessionsReset,
+  canvasBatchIdsReset,
   commitStagingAreaImage,
   discardStagedImages,
 } from 'features/canvas/store/canvasSlice';
@@ -38,7 +38,7 @@ export const addCommitStagingAreaImageListener = () => {
          })
        );
      }
-      dispatch(canvasBatchesAndSessionsReset());
+      dispatch(canvasBatchIdsReset());
    } catch {
      log.error('Failed to cancel canvas batches');
      dispatch(
@@ -30,7 +30,7 @@ export const addInvocationCompleteEventListener = () => {
        `Invocation complete (${action.payload.data.node.type})`
      );

-      const { result, node, graph_execution_state_id } = data;
+      const { result, node, queue_batch_id } = data;

      // This complete event has an associated image output
      if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
@@ -43,7 +43,7 @@ export const addInvocationCompleteEventListener = () => {

        // Add canvas images to the staging area
        if (
-          canvas.sessionIds.includes(graph_execution_state_id) &&
+          canvas.batchIds.includes(queue_batch_id) &&
          [CANVAS_OUTPUT].includes(data.source_node_id)
        ) {
          dispatch(addImageToStagingArea(imageDTO));
@@ -1,5 +1,4 @@
 import { logger } from 'app/logging/logger';
-import { canvasSessionIdAdded } from 'features/canvas/store/canvasSlice';
 import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
 import {
   appSocketQueueItemStatusChanged,
@@ -10,12 +9,11 @@ import { startAppListening } from '../..';
 export const addSocketQueueItemStatusChangedEventListener = () => {
   startAppListening({
     actionCreator: socketQueueItemStatusChanged,
-    effect: (action, { dispatch, getState }) => {
+    effect: (action, { dispatch }) => {
       const log = logger('socketio');
       const {
         queue_item_id: item_id,
-        batch_id,
-        graph_execution_state_id,
+        queue_batch_id,
         status,
       } = action.payload.data;
       log.debug(
@@ -36,11 +34,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
        })
      );

-      const state = getState();
-      if (state.canvas.batchIds.includes(batch_id)) {
-        dispatch(canvasSessionIdAdded(graph_execution_state_id));
-      }
-
      dispatch(
        queueApi.util.invalidateTags([
          'CurrentSessionQueueItem',
@@ -48,7 +41,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
          'SessionQueueStatus',
          { type: 'SessionQueueItem', id: item_id },
          { type: 'SessionQueueItemDTO', id: item_id },
-          { type: 'BatchStatus', id: batch_id },
+          { type: 'BatchStatus', id: queue_batch_id },
        ])
      );
    },
@@ -11,12 +11,12 @@ const selector = createSelector(
   ({ system, canvas }) => {
     const { denoiseProgress } = system;
     const { boundingBox } = canvas.layerState.stagingArea;
-    const { sessionIds } = canvas;
+    const { batchIds } = canvas;

     return {
       boundingBox,
       progressImage:
-        denoiseProgress && sessionIds.includes(denoiseProgress.session_id)
+        denoiseProgress && batchIds.includes(denoiseProgress.batch_id)
          ? denoiseProgress.progress_image
          : undefined,
     };
@@ -85,7 +85,6 @@ export const initialCanvasState: CanvasState = {
   stageDimensions: { width: 0, height: 0 },
   stageScale: 1,
   tool: 'brush',
-  sessionIds: [],
   batchIds: [],
 };

@@ -302,11 +301,7 @@ export const canvasSlice = createSlice({
    canvasBatchIdAdded: (state, action: PayloadAction<string>) => {
      state.batchIds.push(action.payload);
    },
-    canvasSessionIdAdded: (state, action: PayloadAction<string>) => {
-      state.sessionIds.push(action.payload);
-    },
-    canvasBatchesAndSessionsReset: (state) => {
-      state.sessionIds = [];
+    canvasBatchIdsReset: (state) => {
      state.batchIds = [];
    },
    stagingAreaInitialized: (
@@ -879,8 +874,7 @@ export const {
  setShouldAntialias,
  canvasResized,
  canvasBatchIdAdded,
-  canvasSessionIdAdded,
-  canvasBatchesAndSessionsReset,
+  canvasBatchIdsReset,
} = canvasSlice.actions;

export default canvasSlice.reducer;
@@ -166,7 +166,6 @@ export interface CanvasState {
  tool: CanvasTool;
  generationMode?: GenerationMode;
  batchIds: string[];
-  sessionIds: string[];
}

export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
@@ -113,6 +113,7 @@ export const systemSlice = createSlice({
        order,
        progress_image,
        graph_execution_state_id: session_id,
+        queue_batch_id: batch_id,
      } = action.payload.data;

      state.denoiseProgress = {
@@ -122,6 +123,7 @@ export const systemSlice = createSlice({
        percentage: calculateStepPercentage(step, total_steps, order),
        progress_image,
        session_id,
+        batch_id,
      };

      state.status = 'PROCESSING';
@@ -12,6 +12,7 @@ export type SystemStatus =

export type DenoiseProgress = {
  session_id: string;
+  batch_id: string;
  progress_image: ProgressImage | null | undefined;
  step: number;
  total_steps: number;
@@ -34,7 +34,8 @@ export type BaseNode = {

export type ModelLoadStartedEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  model_name: string;
  base_model: BaseModelType;
@@ -44,7 +45,8 @@ export type ModelLoadStartedEvent = {

export type ModelLoadCompletedEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  model_name: string;
  base_model: BaseModelType;
@@ -62,7 +64,8 @@ export type ModelLoadCompletedEvent = {
 */
export type GeneratorProgressEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  node_id: string;
  source_node_id: string;
@@ -81,7 +84,8 @@ export type GeneratorProgressEvent = {
 */
export type InvocationCompleteEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  node: BaseNode;
  source_node_id: string;
@@ -95,7 +99,8 @@ export type InvocationCompleteEvent = {
 */
export type InvocationErrorEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  node: BaseNode;
  source_node_id: string;
@@ -110,7 +115,8 @@ export type InvocationErrorEvent = {
 */
export type InvocationStartedEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  node: BaseNode;
  source_node_id: string;
@@ -123,7 +129,8 @@ export type InvocationStartedEvent = {
 */
export type GraphExecutionStateCompleteEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
};

@@ -134,7 +141,8 @@ export type GraphExecutionStateCompleteEvent = {
 */
export type SessionRetrievalErrorEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  error_type: string;
  error: string;
@@ -147,7 +155,8 @@ export type SessionRetrievalErrorEvent = {
 */
export type InvocationRetrievalErrorEvent = {
  queue_id: string;
-  queue_item_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  graph_execution_state_id: string;
  node_id: string;
  error_type: string;
@@ -161,8 +170,8 @@ export type InvocationRetrievalErrorEvent = {
 */
export type QueueItemStatusChangedEvent = {
  queue_id: string;
-  queue_item_id: string;
-  batch_id: string;
+  queue_item_id: number;
+  queue_batch_id: string;
  session_id: string;
  graph_execution_state_id: string;
  status: components['schemas']['SessionQueueItemDTO']['status'];
@@ -75,7 +75,13 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B

    print(f"invoking {n.id}: {type(n)}")
    o = n.invoke(
-        InvocationContext(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1")
+        InvocationContext(
+            queue_batch_id="1",
+            queue_item_id=1,
+            queue_id=DEFAULT_QUEUE_ID,
+            services=services,
+            graph_execution_state_id="1",
+        )
    )
    g.complete(n.id, o)

@@ -102,7 +102,12 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
    g = mock_invoker.create_execution_state(graph=simple_graph)
-    invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
+    invocation_id = mock_invoker.invoke(
+        session_queue_batch_id="1",
+        session_queue_item_id=1,
+        session_queue_id=DEFAULT_QUEUE_ID,
+        graph_execution_state=g,
+    )
    assert invocation_id is not None

    def has_executed_any(g: GraphExecutionState):
@@ -120,7 +125,11 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
    g = mock_invoker.create_execution_state(graph=simple_graph)
    invocation_id = mock_invoker.invoke(
-        queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
+        session_queue_batch_id="1",
+        session_queue_item_id=1,
+        session_queue_id=DEFAULT_QUEUE_ID,
+        graph_execution_state=g,
+        invoke_all=True,
    )
    assert invocation_id is not None

@@ -140,7 +149,13 @@ def test_handles_errors(mock_invoker: Invoker):
    g = mock_invoker.create_execution_state()
    g.graph.add_node(ErrorInvocation(id="1"))

-    mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
+    mock_invoker.invoke(
+        session_queue_batch_id="1",
+        session_queue_item_id=1,
+        session_queue_id=DEFAULT_QUEUE_ID,
+        graph_execution_state=g,
+        invoke_all=True,
+    )

    def has_executed_all(g: GraphExecutionState):
        g = mock_invoker.services.graph_execution_manager.get(g.id)