mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing out generating a new session for each batch_index
This commit is contained in:
parent
d2f968b902
commit
f080c56771
@ -35,7 +35,12 @@ async def create_session(
|
|||||||
)
|
)
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
|
||||||
|
batch_indices = list()
|
||||||
|
if graph.batches:
|
||||||
|
for batch in graph.batches:
|
||||||
|
batch_indices.append(len(batch.data)-1)
|
||||||
|
session = ApiDependencies.invoker.create_execution_state(graph, batch_indices)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@ -376,8 +376,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.to("cpu")
|
result_latents = result_latents.to("cpu")
|
||||||
import uuid
|
torch.cuda.empty_cache()
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}_{uuid.uuid4()}'
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
@ -818,6 +818,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[
|
results: dict[
|
||||||
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
||||||
@ -855,14 +857,14 @@ class GraphExecutionState(BaseModel):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def next(self, batch_indices: list = list()) -> Optional[BaseInvocation]:
|
def next(self) -> Optional[BaseInvocation]:
|
||||||
"""Gets the next node ready to execute."""
|
"""Gets the next node ready to execute."""
|
||||||
|
|
||||||
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
||||||
# possibly with a timeout?
|
# possibly with a timeout?
|
||||||
|
|
||||||
# If there are no prepared nodes, prepare some nodes
|
|
||||||
self._apply_batch_config()
|
self._apply_batch_config()
|
||||||
|
# If there are no prepared nodes, prepare some nodes
|
||||||
next_node = self._get_next_node()
|
next_node = self._get_next_node()
|
||||||
if next_node is None:
|
if next_node is None:
|
||||||
prepared_id = self._prepare()
|
prepared_id = self._prepare()
|
||||||
@ -871,15 +873,10 @@ class GraphExecutionState(BaseModel):
|
|||||||
while prepared_id is not None:
|
while prepared_id is not None:
|
||||||
prepared_id = self._prepare()
|
prepared_id = self._prepare()
|
||||||
next_node = self._get_next_node()
|
next_node = self._get_next_node()
|
||||||
|
|
||||||
# Get values from edges
|
# Get values from edges
|
||||||
if next_node is not None:
|
if next_node is not None:
|
||||||
self._prepare_inputs(next_node)
|
self._prepare_inputs(next_node)
|
||||||
if next_node is None and sum(self.batch_indices) != 0:
|
|
||||||
for index in range(len(self.batch_indices)):
|
|
||||||
if self.batch_indices[index] > 0:
|
|
||||||
self.batch_indices[index] -= 1
|
|
||||||
self.executed.clear()
|
|
||||||
return self.next()
|
|
||||||
|
|
||||||
# If next is still none, there's no next node, return None
|
# If next is still none, there's no next node, return None
|
||||||
return next_node
|
return next_node
|
||||||
@ -909,7 +906,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
def is_complete(self) -> bool:
|
def is_complete(self) -> bool:
|
||||||
"""Returns true if the graph is complete"""
|
"""Returns true if the graph is complete"""
|
||||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||||
return sum(self.batch_indices) == 0 and (self.has_error() or all((k in self.executed for k in node_ids)))
|
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||||
|
|
||||||
def has_error(self) -> bool:
|
def has_error(self) -> bool:
|
||||||
"""Returns true if the graph has any errors"""
|
"""Returns true if the graph has any errors"""
|
||||||
|
@ -41,14 +41,9 @@ class Invoker:
|
|||||||
|
|
||||||
return invocation.id
|
return invocation.id
|
||||||
|
|
||||||
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
def create_execution_state(self, graph: Optional[Graph] = None, batch_indices: list[int] = list()) -> GraphExecutionState:
|
||||||
"""Creates a new execution state for the given graph"""
|
"""Creates a new execution state for the given graph"""
|
||||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
new_state = GraphExecutionState(graph=Graph() if graph is None else graph, batch_indices=batch_indices)
|
||||||
if graph.batches:
|
|
||||||
batch_indices = list()
|
|
||||||
for batch in graph.batches:
|
|
||||||
batch_indices.append(len(batch.data)-1)
|
|
||||||
new_state.batch_indices = batch_indices
|
|
||||||
self.services.graph_execution_manager.set(new_state)
|
self.services.graph_execution_manager.set(new_state)
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from ..invocations.baseinvocation import InvocationContext
|
|||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
from .graph import GraphExecutionState
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
@ -73,7 +74,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||||
|
|
||||||
@ -165,6 +166,15 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc()
|
error=traceback.format_exc()
|
||||||
)
|
)
|
||||||
|
elif queue_item.invoke_all and sum(graph_execution_state.batch_indices) > 0:
|
||||||
|
batch_indicies = graph_execution_state.batch_indices.copy()
|
||||||
|
for index in range(len(batch_indicies)):
|
||||||
|
if batch_indicies[index] > 0:
|
||||||
|
batch_indicies[index] -= 1
|
||||||
|
break
|
||||||
|
new_ges = GraphExecutionState(graph=graph_execution_state.graph, batch_indices=batch_indicies)
|
||||||
|
self.__invoker.services.graph_execution_manager.set(new_ges)
|
||||||
|
self.__invoker.invoke(new_ges, invoke_all=True)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
graph_execution_state.id
|
graph_execution_state.id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user