mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run black formatting
This commit is contained in:
parent
02aa93c67c
commit
a61685696f
@ -48,7 +48,7 @@ async def create_session(
|
|||||||
)
|
)
|
||||||
async def create_batch(
|
async def create_batch(
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
batches: list[Batch] = Body(description="Batch config to apply to the given graph")
|
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||||
) -> BatchProcess:
|
) -> BatchProcess:
|
||||||
"""Creates and starts a new new batch process"""
|
"""Creates and starts a new new batch process"""
|
||||||
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
|
session = ApiDependencies.invoker.services.batch_manager.run_batch_process(batches, graph)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import uuid
|
import uuid
|
||||||
import copy
|
import copy
|
||||||
@ -21,6 +20,8 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
|
|
||||||
|
|
||||||
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Batch(BaseModel):
|
class Batch(BaseModel):
|
||||||
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
data: list[InvocationsUnion] = Field(description="Mapping of ")
|
||||||
node_id: str = Field(description="ID of the node to batch")
|
node_id: str = Field(description="ID of the node to batch")
|
||||||
@ -28,52 +29,44 @@ class Batch(BaseModel):
|
|||||||
|
|
||||||
class BatchProcess(BaseModel):
|
class BatchProcess(BaseModel):
|
||||||
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch")
|
||||||
sessions: list[str] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
sessions: list[str] = Field(
|
||||||
|
description="Tracker for which batch is currently being processed", default_factory=list
|
||||||
|
)
|
||||||
batches: list[Batch] = Field(
|
batches: list[Batch] = Field(
|
||||||
description="List of batch configs to apply to this session",
|
description="List of batch configs to apply to this session",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
batch_indices: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
|
batch_indices: list[int] = Field(
|
||||||
|
description="Tracker for which batch is currently being processed", default_factory=list
|
||||||
|
)
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
|
|
||||||
class BatchManagerBase(ABC):
|
class BatchManagerBase(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start(
|
def start(self, invoker: Invoker):
|
||||||
self,
|
|
||||||
invoker: Invoker
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run_batch_process(
|
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||||
self,
|
|
||||||
batches: list[Batch],
|
|
||||||
graph: Graph
|
|
||||||
) -> BatchProcess:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_batch_process(
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
self,
|
|
||||||
batch_process_id: str
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BatchManager(BatchManagerBase):
|
class BatchManager(BatchManagerBase):
|
||||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
|
|
||||||
__invoker: Invoker
|
__invoker: Invoker
|
||||||
__batches: list[BatchProcess]
|
__batches: list[BatchProcess]
|
||||||
|
|
||||||
|
|
||||||
def start(self, invoker) -> None:
|
def start(self, invoker) -> None:
|
||||||
# if we do want multithreading at some point, we could make this configurable
|
# if we do want multithreading at some point, we could make this configurable
|
||||||
self.__invoker = invoker
|
self.__invoker = invoker
|
||||||
self.__batches = list()
|
self.__batches = list()
|
||||||
local_handler.register(
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||||
event_name=EventServiceBase.session_event, _func=self.on_event
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_event(self, event: Event):
|
async def on_event(self, event: Event):
|
||||||
event_name = event[1]["event"]
|
event_name = event[1]["event"]
|
||||||
@ -90,7 +83,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
data = event[1]["data"]
|
data = event[1]["data"]
|
||||||
batchTarget = None
|
batchTarget = None
|
||||||
for batch in self.__batches:
|
for batch in self.__batches:
|
||||||
if data['graph_execution_state_id'] in batch.sessions:
|
if data["graph_execution_state_id"] in batch.sessions:
|
||||||
batchTarget = batch
|
batchTarget = batch
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -114,7 +107,7 @@ class BatchManager(BatchManagerBase):
|
|||||||
sorted_nodes = nx.topological_sort(g)
|
sorted_nodes = nx.topological_sort(g)
|
||||||
for npath in sorted_nodes:
|
for npath in sorted_nodes:
|
||||||
node = graph.get_node(npath)
|
node = graph.get_node(npath)
|
||||||
(index, batch) = next(((i,b) for i,b in enumerate(batches) if b.node_id in node.id), (None, None))
|
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||||
if batch:
|
if batch:
|
||||||
batch_index = batch_process.batch_indices[index]
|
batch_index = batch_process.batch_indices[index]
|
||||||
datum = batch.data[batch_index]
|
datum = batch.data[batch_index]
|
||||||
@ -123,7 +116,6 @@ class BatchManager(BatchManagerBase):
|
|||||||
|
|
||||||
return GraphExecutionState(graph=graph)
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
|
|
||||||
def _next_batch_index(self, batch_process: BatchProcess):
|
def _next_batch_index(self, batch_process: BatchProcess):
|
||||||
batch_indicies = batch_process.batch_indices.copy()
|
batch_indicies = batch_process.batch_indices.copy()
|
||||||
for index in range(len(batch_indicies)):
|
for index in range(len(batch_indicies)):
|
||||||
@ -132,19 +124,14 @@ class BatchManager(BatchManagerBase):
|
|||||||
break
|
break
|
||||||
return batch_indicies
|
return batch_indicies
|
||||||
|
|
||||||
|
def run_batch_process(self, batches: list[Batch], graph: Graph) -> BatchProcess:
|
||||||
def run_batch_process(
|
|
||||||
self,
|
|
||||||
batches: list[Batch],
|
|
||||||
graph: Graph
|
|
||||||
) -> BatchProcess:
|
|
||||||
batch_indices = list()
|
batch_indices = list()
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
batch_indices.append(len(batch.data)-1)
|
batch_indices.append(len(batch.data) - 1)
|
||||||
batch_process = BatchProcess(
|
batch_process = BatchProcess(
|
||||||
batches = batches,
|
batches=batches,
|
||||||
batch_indices = batch_indices,
|
batch_indices=batch_indices,
|
||||||
graph = graph,
|
graph=graph,
|
||||||
)
|
)
|
||||||
ges = self._next_batch_session(batch_process)
|
ges = self._next_batch_session(batch_process)
|
||||||
batch_process.sessions.append(ges.id)
|
batch_process.sessions.append(ges.id)
|
||||||
@ -153,8 +140,5 @@ class BatchManager(BatchManagerBase):
|
|||||||
self.__invoker.invoke(ges, invoke_all=True)
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
return batch_process
|
return batch_process
|
||||||
|
|
||||||
def cancel_batch_process(
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
self,
|
|
||||||
batch_process_id: str
|
|
||||||
):
|
|
||||||
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||||
|
Loading…
Reference in New Issue
Block a user