Run black formatting

This commit is contained in:
Brandon Rising 2023-08-01 16:41:40 -04:00
parent 02aa93c67c
commit a61685696f
2 changed files with 27 additions and 43 deletions

View File

@ -48,7 +48,7 @@ async def create_session(
) )
async def create_batch( async def create_batch(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"), graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
batches: list[Batch] = Body(description="Batch config to apply to the given graph") batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess: ) -> BatchProcess:
"""Creates and starts a new new batch process""" """Creates and starts a new new batch process"""
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph) session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)

View File

@ -1,4 +1,3 @@
import networkx as nx import networkx as nx
import uuid import uuid
import copy import copy
@ -21,6 +20,8 @@ from invokeai.app.services.invoker import Invoker
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
class Batch(BaseModel): class Batch(BaseModel):
data: list[InvocationsUnion] = Field(description="Mapping of ") data: list[InvocationsUnion] = Field(description="Mapping of ")
node_id: str = Field(description="ID of the node to batch") node_id: str = Field(description="ID of the node to batch")
@ -28,52 +29,44 @@ class Batch(BaseModel):
class BatchProcess(BaseModel): class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch") batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list) sessions: list[str] = Field(
description="Tracker for which batch is currently being processed", default_factory=list
)
batches: list[Batch] = Field( batches: list[Batch] = Field(
description="List of batch configs to apply to this session", description="List of batch configs to apply to this session",
default_factory=list, default_factory=list,
) )
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list) batch_indices: list[int] = Field(
description="Tracker for which batch is currently being processed", default_factory=list
)
graph: Graph = Field(description="The graph being executed") graph: Graph = Field(description="The graph being executed")
class BatchManagerBase(ABC): class BatchManagerBase(ABC):
@abstractmethod @abstractmethod
def start( def start(self, invoker: Invoker):
self,
invoker: Invoker
):
pass pass
@abstractmethod @abstractmethod
def run_batch_process( def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
self,
batches: list[Batch],
graph: Graph
) -> BatchProcess:
pass pass
@abstractmethod @abstractmethod
def cancel_batch_process( def cancel_batch_process(self, batch_process_id: str):
self,
batch_process_id: str
):
pass pass
class BatchManager(BatchManagerBase): class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs""" """Responsible for managing currently running and scheduled batch jobs"""
__invoker: Invoker __invoker: Invoker
__batches: list[BatchProcess] __batches: list[BatchProcess]
def start(self, invoker) -> None: def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable # if we do want multithreading at some point, we could make this configurable
self.__invoker = invoker self.__invoker = invoker
self.__batches = list() self.__batches = list()
local_handler.register( local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
event_name=EventServiceBase.session_event, _func=self.on_event
)
async def on_event(self, event: Event): async def on_event(self, event: Event):
event_name = event[1]["event"] event_name = event[1]["event"]
@ -85,22 +78,22 @@ class BatchManager(BatchManagerBase):
await self.process(event) await self.process(event)
return event return event
async def process(self, event: Event): async def process(self, event: Event):
data = event[1]["data"] data = event[1]["data"]
batchTarget = None batchTarget = None
for batch in self.__batches: for batch in self.__batches:
if data['graph_execution_state_id'] in batch.sessions: if data["graph_execution_state_id"] in batch.sessions:
batchTarget = batch batchTarget = batch
break break
if batchTarget == None: if batchTarget == None:
return return
if sum(batchTarget.batch_indices) == 0: if sum(batchTarget.batch_indices) == 0:
self.__batches = [batch for batch in self.__batches if batch != batchTarget] self.__batches = [batch for batch in self.__batches if batch != batchTarget]
return return
batchTarget.batch_indices = self._next_batch_index(batchTarget) batchTarget.batch_indices = self._next_batch_index(batchTarget)
ges = self._next_batch_session(batchTarget) ges = self._next_batch_session(batchTarget)
batchTarget.sessions.append(ges.id) batchTarget.sessions.append(ges.id)
@ -114,15 +107,14 @@ class BatchManager(BatchManagerBase):
sorted_nodes = nx.topological_sort(g) sorted_nodes = nx.topological_sort(g)
for npath in sorted_nodes: for npath in sorted_nodes:
node = graph.get_node(npath) node = graph.get_node(npath)
(index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None)) (index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
if batch: if batch:
batch_index = batch_process.batch_indices[index] batch_index = batch_process.batch_indices[index]
datum = batch.data[batch_index] datum = batch.data[batch_index]
datum.id = node.id datum.id = node.id
graph.update_node(npath, datum) graph.update_node(npath, datum)
return GraphExecutionState(graph=graph)
return GraphExecutionState(graph=graph)
def _next_batch_index(self, batch_process: BatchProcess): def _next_batch_index(self, batch_process: BatchProcess):
batch_indicies = batch_process.batch_indices.copy() batch_indicies = batch_process.batch_indices.copy()
@ -132,19 +124,14 @@ class BatchManager(BatchManagerBase):
break break
return batch_indicies return batch_indicies
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
def run_batch_process(
self,
batches: list[Batch],
graph: Graph
) -> BatchProcess:
batch_indices = list() batch_indices = list()
for batch in batches: for batch in batches:
batch_indices.append(len(batch.data)-1) batch_indices.append(len(batch.data) - 1)
batch_process = BatchProcess( batch_process = BatchProcess(
batches = batches, batches=batches,
batch_indices = batch_indices, batch_indices=batch_indices,
graph = graph, graph=graph,
) )
ges = self._next_batch_session(batch_process) ges = self._next_batch_session(batch_process)
batch_process.sessions.append(ges.id) batch_process.sessions.append(ges.id)
@ -153,8 +140,5 @@ class BatchManager(BatchManagerBase):
self.__invoker.invoke(ges, invoke_all=True) self.__invoker.invoke(ges, invoke_all=True)
return batch_process return batch_process
def cancel_batch_process( def cancel_batch_process(self, batch_process_id: str):
self,
batch_process_id: str
):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id] self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]