mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run black formatting
This commit is contained in:
parent
02aa93c67c
commit
a61685696f
@ -48,7 +48,7 @@ async def create_session(
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph")
|
||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcess:
|
||||
"""Creates and starts a new new batch process"""
|
||||
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
import networkx as nx
|
||||
import uuid
|
||||
import copy
|
||||
@ -21,6 +20,8 @@ from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||||
node_id: str = Field(description="ID of the node to batch")
|
||||
@ -28,52 +29,44 @@ class Batch(BaseModel):
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||||
sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||
sessions: list[str] = Field(
|
||||
description="Tracker for which batch is currently being processed", default_factory=list
|
||||
)
|
||||
batches: list[Batch] = Field(
|
||||
description="List of batch configs to apply to this session",
|
||||
default_factory=list,
|
||||
)
|
||||
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
||||
batch_indices: list[int] = Field(
|
||||
description="Tracker for which batch is currently being processed", default_factory=list
|
||||
)
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker
|
||||
):
|
||||
def start(self, invoker: Invoker):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(
|
||||
self,
|
||||
batches: list[Batch],
|
||||
graph: Graph
|
||||
) -> BatchProcess:
|
||||
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_batch_process(
|
||||
self,
|
||||
batch_process_id: str
|
||||
):
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
pass
|
||||
|
||||
|
||||
class BatchManager(BatchManagerBase):
|
||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||
|
||||
__invoker: Invoker
|
||||
__batches: list[BatchProcess]
|
||||
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__invoker = invoker
|
||||
self.__batches = list()
|
||||
local_handler.register(
|
||||
event_name=EventServiceBase.session_event, _func=self.on_event
|
||||
)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
event_name = event[1]["event"]
|
||||
@ -85,22 +78,22 @@ class BatchManager(BatchManagerBase):
|
||||
await self.process(event)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
async def process(self, event: Event):
|
||||
data = event[1]["data"]
|
||||
batchTarget = None
|
||||
for batch in self.__batches:
|
||||
if data['graph_execution_state_id'] in batch.sessions:
|
||||
if data["graph_execution_state_id"] in batch.sessions:
|
||||
batchTarget = batch
|
||||
break
|
||||
|
||||
|
||||
if batchTarget == None:
|
||||
return
|
||||
|
||||
|
||||
if sum(batchTarget.batch_indices) == 0:
|
||||
self.__batches = [batch for batch in self.__batches if batch != batchTarget]
|
||||
return
|
||||
|
||||
|
||||
batchTarget.batch_indices = self._next_batch_index(batchTarget)
|
||||
ges = self._next_batch_session(batchTarget)
|
||||
batchTarget.sessions.append(ges.id)
|
||||
@ -114,15 +107,14 @@ class BatchManager(BatchManagerBase):
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
for npath in sorted_nodes:
|
||||
node = graph.get_node(npath)
|
||||
(index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||
if batch:
|
||||
batch_index = batch_process.batch_indices[index]
|
||||
datum = batch.data[batch_index]
|
||||
datum.id = node.id
|
||||
graph.update_node(npath, datum)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
def _next_batch_index(self, batch_process: BatchProcess):
|
||||
batch_indicies = batch_process.batch_indices.copy()
|
||||
@ -132,19 +124,14 @@ class BatchManager(BatchManagerBase):
|
||||
break
|
||||
return batch_indicies
|
||||
|
||||
|
||||
def run_batch_process(
|
||||
self,
|
||||
batches: list[Batch],
|
||||
graph: Graph
|
||||
) -> BatchProcess:
|
||||
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||
batch_indices = list()
|
||||
for batch in batches:
|
||||
batch_indices.append(len(batch.data)-1)
|
||||
batch_indices.append(len(batch.data) - 1)
|
||||
batch_process = BatchProcess(
|
||||
batches = batches,
|
||||
batch_indices = batch_indices,
|
||||
graph = graph,
|
||||
batches=batches,
|
||||
batch_indices=batch_indices,
|
||||
graph=graph,
|
||||
)
|
||||
ges = self._next_batch_session(batch_process)
|
||||
batch_process.sessions.append(ges.id)
|
||||
@ -153,8 +140,5 @@ class BatchManager(BatchManagerBase):
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
return batch_process
|
||||
|
||||
def cancel_batch_process(
|
||||
self,
|
||||
batch_process_id: str
|
||||
):
|
||||
def cancel_batch_process(self, batch_process_id: str):
|
||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||
|
Loading…
Reference in New Issue
Block a user