InvokeAI/invokeai/app/services/invocation_queue.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

79 lines
2.6 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
2023-04-13 04:23:15 +00:00
import time
from abc import ABC, abstractmethod
from queue import Queue
2023-08-18 14:57:18 +00:00
from typing import Optional
from pydantic import BaseModel, Field
# One unit of work for an invocation queue: a single node (invocation_id) of a
# graph execution state, tagged with the session queue, session queue item and
# batch it originated from.
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a model's docstring into its JSON schema "description", so a
# docstring here would change the generated schema.
class InvocationQueueItem(BaseModel):
    graph_execution_state_id: str = Field(
        description="The ID of the graph execution state",
    )
    invocation_id: str = Field(
        description="The ID of the node being invoked",
    )
    session_queue_id: str = Field(
        description="The ID of the session queue from which this invocation queue item came",
    )
    session_queue_item_id: int = Field(
        description="The ID of session queue item from which this invocation queue item came",
    )
    session_queue_batch_id: str = Field(
        description="The ID of the session batch from which this invocation queue item came",
    )
    # Meaning is not visible in this file; presumably asks the processor to keep
    # invoking the rest of the graph rather than just this node -- TODO confirm.
    invoke_all: bool = Field(False)
    # Creation time as returned by time.time(). The in-memory queue compares it
    # against cancellation times to decide whether to drop this item.
    timestamp: float = Field(default_factory=time.time)
class InvocationQueueABC(ABC):
    """Interface that every invocation queue implementation must provide."""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        """Remove and return the next item to process.

        NOTE(review): the in-memory implementation in this module blocks until
        an item is available and can hand back ``None`` when ``None`` was put,
        despite this annotation -- confirm the intended contract.
        """

    @abstractmethod
    def put(self, item: Optional[InvocationQueueItem]) -> None:
        """Enqueue ``item``.

        ``None`` is accepted by the signature; presumably it is used as a
        sentinel for the consumer -- TODO confirm against callers.
        """

    @abstractmethod
    def cancel(self, graph_execution_state_id: str) -> None:
        """Mark the given graph execution state as canceled."""

    @abstractmethod
    def is_canceled(self, graph_execution_state_id: str) -> bool:
        """Return whether the given graph execution state is currently canceled."""
2023-03-17 03:05:36 +00:00
class MemoryInvocationQueue(InvocationQueueABC):
    """In-process invocation queue backed by a stdlib ``queue.Queue``.

    Cancellation is tracked per graph execution state as the ``time.time()``
    value of the most recent ``cancel`` call. ``get`` discards queued items of
    a canceled state whose ``timestamp`` is older than that value.
    """

    __queue: Queue
    # graph_execution_state_id -> time.time() of the latest cancel() call
    __cancellations: dict[str, float]

    def __init__(self):
        self.__queue = Queue()
        self.__cancellations = {}

    def get(self) -> InvocationQueueItem:
        """Remove and return the next queued item that was not canceled.

        Blocks until an item is available (``Queue.get()`` with no timeout).
        Items whose graph execution state was canceled after the item was
        created are skipped. Values that are not ``InvocationQueueItem``
        (``put`` accepts ``None``) are returned unchanged; the return
        annotation is kept identical to ``InvocationQueueABC.get`` even though
        ``None`` can be returned.
        """
        item = self.__queue.get()
        while (
            isinstance(item, InvocationQueueItem)
            and item.graph_execution_state_id in self.__cancellations
            and self.__cancellations[item.graph_execution_state_id] > item.timestamp
        ):
            item = self.__queue.get()

        # Forget cancellations that predate the item just dequeued. Only real
        # items carry a timestamp: a None sentinel must not reach item.timestamp,
        # which would raise AttributeError whenever a cancellation is pending.
        if isinstance(item, InvocationQueueItem):
            # Snapshot the keys because entries are deleted while iterating.
            for graph_execution_state_id in list(self.__cancellations.keys()):
                if self.__cancellations[graph_execution_state_id] < item.timestamp:
                    del self.__cancellations[graph_execution_state_id]
        return item

    def put(self, item: Optional[InvocationQueueItem]) -> None:
        """Append ``item`` (or ``None``) to the queue as-is."""
        self.__queue.put(item)

    def cancel(self, graph_execution_state_id: str) -> None:
        """Record a cancellation of ``graph_execution_state_id`` at the current time."""
        # Always overwrite with the latest time: keeping only the first value
        # meant a repeated cancel, issued while the earlier entry was still
        # pending, did not cover items created between the two calls.
        self.__cancellations[graph_execution_state_id] = time.time()

    def is_canceled(self, graph_execution_state_id: str) -> bool:
        """Return True while a cancellation entry exists for the given state.

        ``get`` removes the entry once it dequeues an item created after the
        cancellation, so this can become False again.
        """
        return graph_execution_state_id in self.__cancellations