# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

from pydantic import Field

from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from .math import FloatOutput, IntOutput

# Pass-through parameter nodes - used by subgraphs
#
# Every Param*Invocation below holds a single value field and its invoke()
# returns that value, unmodified, wrapped in the matching *Output model.
#
# NOTE(review): the base classes are presumably pydantic models, in which case
# the class docstrings surface as schema descriptions - keep them short and
# accurate, and put longer explanations in `#` comments instead.


class ParamIntInvocation(BaseInvocation):
    """An integer parameter"""
    # fmt: off
    type: Literal["param_int"] = "param_int"  # discriminator for this node type
    a: int = Field(default=0, description="The integer value")
    # fmt: on

    class Config(InvocationConfig):
        # UI metadata: search tags and display title for this node.
        schema_extra = {
            "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
        """Return the stored integer as an IntOutput; `context` is unused."""
        return IntOutput(a=self.a)


class ParamFloatInvocation(BaseInvocation):
    """A float parameter"""
    # fmt: off
    type: Literal["param_float"] = "param_float"  # discriminator for this node type
    param: float = Field(default=0.0, description="The float value")
    # fmt: on

    class Config(InvocationConfig):
        # UI metadata: search tags and display title for this node.
        schema_extra = {
            "ui": {"tags": ["param", "float"], "title": "Float Parameter"},
        }

    def invoke(self, context: InvocationContext) -> FloatOutput:
        """Return the stored float as a FloatOutput; `context` is unused."""
        return FloatOutput(param=self.param)


class StringOutput(BaseInvocationOutput):
    """A string output"""
    type: Literal["string_output"] = "string_output"
    # NOTE(review): annotated as `str` but defaults to None; left unchanged so
    # the generated schema stays the same - confirm whether None is intended.
    text: str = Field(default=None, description="The output string")


class ParamStringInvocation(BaseInvocation):
    """A string parameter"""
    type: Literal["param_string"] = "param_string"  # discriminator for this node type
    text: str = Field(default="", description="The string value")

    class Config(InvocationConfig):
        # UI metadata: search tags and display title for this node.
        schema_extra = {
            "ui": {"tags": ["param", "string"], "title": "String Parameter"},
        }

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Return the stored string as a StringOutput; `context` is unused."""
        return StringOutput(text=self.text)


class PromptOutput(BaseInvocationOutput):
    """A prompt output"""
    type: Literal["prompt_output"] = "prompt_output"
    # NOTE(review): annotated as `str` but defaults to None; left unchanged so
    # the generated schema stays the same - confirm whether None is intended.
    prompt: str = Field(default=None, description="The output prompt")


class ParamPromptInvocation(BaseInvocation):
    """A prompt input parameter"""
    type: Literal["param_prompt"] = "param_prompt"  # discriminator for this node type
    prompt: str = Field(default="", description="The prompt value")

    class Config(InvocationConfig):
        # UI metadata: search tags and display title for this node.
        schema_extra = {
            "ui": {"tags": ["param", "prompt"], "title": "Prompt"},
        }

    def invoke(self, context: InvocationContext) -> PromptOutput:
        """Return the stored prompt as a PromptOutput; `context` is unused."""
        return PromptOutput(prompt=self.prompt)