# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import time
from abc import ABC, abstractmethod
from queue import Queue
from typing import Optional

from pydantic import BaseModel, Field


class InvocationQueueItem(BaseModel):
    """A single unit of work placed on an invocation queue.

    ``timestamp`` is set when the item is created and is compared against
    cancellation times by queue implementations to decide whether to skip it.
    """

    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
    invocation_id: str = Field(description="The ID of the node being invoked")
    session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
    session_queue_item_id: int = Field(
        description="The ID of session queue item from which this invocation queue item came"
    )
    invoke_all: bool = Field(default=False)
    # Creation time in seconds since the epoch (from time.time()).
    timestamp: float = Field(default_factory=time.time)


class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> Optional[InvocationQueueItem]:
        """Return the next item from the queue.

        Because ``put`` accepts ``None``, ``None`` may be returned as well.
        """

    @abstractmethod
    def put(self, item: Optional[InvocationQueueItem]) -> None:
        """Enqueue ``item``; ``None`` is accepted (presumably as a wake-up/stop signal — confirm with callers)."""

    @abstractmethod
    def cancel(self, graph_execution_state_id: str) -> None:
        """Mark the given graph execution state as canceled."""

    @abstractmethod
    def is_canceled(self, graph_execution_state_id: str) -> bool:
        """Return whether the given graph execution state is currently marked canceled."""


class MemoryInvocationQueue(InvocationQueueABC):
    """In-process invocation queue backed by :class:`queue.Queue`.

    Cancellation is recorded per graph execution state as the time ``cancel``
    was first called. ``get`` discards queued items for that state whose
    ``timestamp`` is older than the cancellation time.
    """

    __queue: Queue
    # graph_execution_state_id -> time.time() at which cancel() was first called
    __cancellations: dict[str, float]

    def __init__(self):
        self.__queue = Queue()
        self.__cancellations = {}

    def get(self) -> Optional[InvocationQueueItem]:
        """Block until an item is available and return it.

        Items belonging to a canceled graph execution state that were created
        before the cancellation are silently skipped. Objects that are not
        ``InvocationQueueItem`` instances (e.g. a ``None`` put on the queue)
        are returned as-is.
        """
        item = self.__queue.get()

        # Skip items created before their graph execution state was canceled.
        while (
            isinstance(item, InvocationQueueItem)
            and item.graph_execution_state_id in self.__cancellations
            and self.__cancellations[item.graph_execution_state_id] > item.timestamp
        ):
            item = self.__queue.get()

        # Prune cancellation records older than the item being returned. A
        # non-item (such as None) has no timestamp, so pruning is skipped
        # rather than raising AttributeError.
        if isinstance(item, InvocationQueueItem):
            for graph_execution_state_id in list(self.__cancellations):
                if self.__cancellations[graph_execution_state_id] < item.timestamp:
                    del self.__cancellations[graph_execution_state_id]

        return item

    def put(self, item: Optional[InvocationQueueItem]) -> None:
        """Append ``item`` (or ``None``) to the underlying queue."""
        self.__queue.put(item)

    def cancel(self, graph_execution_state_id: str) -> None:
        """Record a cancellation time for the state; repeat calls keep the first time."""
        if graph_execution_state_id not in self.__cancellations:
            self.__cancellations[graph_execution_state_id] = time.time()

    def is_canceled(self, graph_execution_state_id: str) -> bool:
        """Return True while a cancellation record exists for the state.

        Records are removed by ``get`` once an item newer than the
        cancellation is dequeued.
        """
        return graph_execution_state_id in self.__cancellations