mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators. The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation. Category is a new invocation metadata, but it is not used by the frontend just yet. - `@invocation()` decorator for invocations ```py @invocation( "sdxl_compel_prompt", title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ... ``` - `@invocation_output()` decorator for invocation outputs ```py @invocation_output("clip_skip_output") class ClipSkipInvocationOutput(BaseInvocationOutput): ... ``` - update invocation docs - add category to decorator - regen frontend types
This commit is contained in:
@ -1,65 +1,63 @@
|
||||
from typing import Any, Callable, Literal, Union
|
||||
from typing import Any, Callable, Union
|
||||
from pydantic import Field
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
|
||||
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@invocation_output("test_list_output")
|
||||
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["test_list_output"] = "test_list_output"
|
||||
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_list")
|
||||
class ListPassThroughInvocation(BaseInvocation):
|
||||
type: Literal["test_list"] = "test_list"
|
||||
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||
|
||||
|
||||
@invocation_output("test_prompt_output")
|
||||
class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["test_prompt_output"] = "test_prompt_output"
|
||||
|
||||
prompt: str = Field(default="")
|
||||
|
||||
|
||||
@invocation("test_prompt")
|
||||
class PromptTestInvocation(BaseInvocation):
|
||||
type: Literal["test_prompt"] = "test_prompt"
|
||||
|
||||
prompt: str = Field(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||
|
||||
|
||||
@invocation("test_error")
|
||||
class ErrorInvocation(BaseInvocation):
|
||||
type: Literal["test_error"] = "test_error"
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
raise Exception("This invocation is supposed to fail")
|
||||
|
||||
|
||||
@invocation_output("test_image_output")
|
||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["test_image_output"] = "test_image_output"
|
||||
|
||||
image: ImageField = Field()
|
||||
|
||||
|
||||
@invocation("test_text_to_image")
|
||||
class TextToImageTestInvocation(BaseInvocation):
|
||||
type: Literal["test_text_to_image"] = "test_text_to_image"
|
||||
|
||||
prompt: str = Field(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@invocation("test_image_to_image")
|
||||
class ImageToImageTestInvocation(BaseInvocation):
|
||||
type: Literal["test_image_to_image"] = "test_image_to_image"
|
||||
|
||||
prompt: str = Field(default="")
|
||||
image: Union[ImageField, None] = Field(default=None)
|
||||
|
||||
@ -67,13 +65,13 @@ class ImageToImageTestInvocation(BaseInvocation):
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@invocation_output("test_prompt_collection_output")
|
||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["test_prompt_collection_output"] = "test_prompt_collection_output"
|
||||
collection: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_prompt_collection")
|
||||
class PromptCollectionTestInvocation(BaseInvocation):
|
||||
type: Literal["test_prompt_collection"] = "test_prompt_collection"
|
||||
collection: list[str] = Field()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
|
Reference in New Issue
Block a user