nodes: ensure Graph and GraphExecutionState ids are cast to str on instantiation

This commit is contained in:
Eugene 2023-04-14 11:17:40 -04:00 committed by Eugene Brodsky
parent cbd1a7263a
commit 570c3fe690
2 changed files with 3 additions and 7 deletions

View File

@ -25,7 +25,6 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput, BaseInvocationOutput,
InvocationContext, InvocationContext,
) )
from .invocation_services import InvocationServices
class EdgeConnection(BaseModel): class EdgeConnection(BaseModel):
@ -214,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel): class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4) id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict description="The nodes in this graph", default_factory=dict
@ -749,9 +748,7 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel): class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution""" """Tracks the state of a graph execution"""
id: str = Field( id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__()
)
# TODO: Store a reference to the graph instead of the actual graph? # TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed") graph: Graph = Field(description="The graph being executed")

View File

@ -3,13 +3,12 @@
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from uuid import UUID
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
graph_execution_state_id: UUID graph_execution_state_id: str
invocation_id: str invocation_id: str
invoke_all: bool invoke_all: bool
timestamp: float = Field(default_factory=time.time) timestamp: float = Field(default_factory=time.time)