diff --git a/invokeai/frontend/web/static/legacy_web/index.js b/invokeai/frontend/web/static/legacy_web/index.js
deleted file mode 100644
index a150f3f2e9..0000000000
--- a/invokeai/frontend/web/static/legacy_web/index.js
+++ /dev/null
@@ -1,234 +0,0 @@
-function toBase64(file) {
- return new Promise((resolve, reject) => {
- const r = new FileReader();
- r.readAsDataURL(file);
- r.onload = () => resolve(r.result);
- r.onerror = (error) => reject(error);
- });
-}
-
-function appendOutput(src, seed, config) {
- let outputNode = document.createElement('figure');
-
- let variations = config.with_variations;
- if (config.variation_amount > 0) {
- variations =
- (variations ? variations + ',' : '') +
- seed +
- ':' +
- config.variation_amount;
- }
- let baseseed =
- config.with_variations || config.variation_amount > 0 ? config.seed : seed;
- let altText =
- baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
-
- // img needs width and height for lazy loading to work
- const figureContents = `
- <a href="${src}">
- <img src="${src}" alt="${altText}" title="${altText}" loading="lazy" width="256" height="256">
- </a>
- <figcaption>${seed}</figcaption>
- `;
-
- outputNode.innerHTML = figureContents;
- let figcaption = outputNode.querySelector('figcaption');
-
- // Reload image config
- figcaption.addEventListener('click', () => {
- let form = document.querySelector('#generate-form');
- for (const [k, v] of new FormData(form)) {
- if (k == 'initimg') {
- continue;
- }
- form.querySelector(`*[name=${k}]`).value = config[k];
- }
-
- document.querySelector('#seed').value = baseseed;
- document.querySelector('#with_variations').value = variations || '';
- if (document.querySelector('#variation_amount').value <= 0) {
- document.querySelector('#variation_amount').value = 0.2;
- }
-
- saveFields(document.querySelector('#generate-form'));
- });
-
- document.querySelector('#results').prepend(outputNode);
-}
-
-function saveFields(form) {
- for (const [k, v] of new FormData(form)) {
- if (typeof v !== 'object') {
- // Don't save 'file' type
- localStorage.setItem(k, v);
- }
- }
-}
-
-function loadFields(form) {
- for (const [k, v] of new FormData(form)) {
- const item = localStorage.getItem(k);
- if (item != null) {
- form.querySelector(`*[name=${k}]`).value = item;
- }
- }
-}
-
-function clearFields(form) {
- localStorage.clear();
- let prompt = form.prompt.value;
- form.reset();
- form.prompt.value = prompt;
-}
-
-const BLANK_IMAGE_URL =
- 'data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" width="1" height="1"></svg>';
-async function generateSubmit(form) {
- const prompt = document.querySelector('#prompt').value;
-
- // Convert file data to base64
- let formData = Object.fromEntries(new FormData(form));
- formData.initimg_name = formData.initimg.name;
- formData.initimg =
- formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
-
- let strength = formData.strength;
- let totalSteps = formData.initimg
- ? Math.floor(strength * formData.steps)
- : formData.steps;
-
- let progressSectionEle = document.querySelector('#progress-section');
- progressSectionEle.style.display = 'initial';
- let progressEle = document.querySelector('#progress-bar');
- progressEle.setAttribute('max', totalSteps);
- let progressImageEle = document.querySelector('#progress-image');
- progressImageEle.src = BLANK_IMAGE_URL;
-
- progressImageEle.style.display = {}.hasOwnProperty.call(
- formData,
- 'progress_images'
- )
- ? 'initial'
- : 'none';
-
- // Post as JSON, using Fetch streaming to get results
- fetch(form.action, {
- method: form.method,
- body: JSON.stringify(formData),
- }).then(async (response) => {
- const reader = response.body.getReader();
-
- let noOutputs = true;
- while (true) {
- let { value, done } = await reader.read();
- value = new TextDecoder().decode(value);
- if (done) {
- progressSectionEle.style.display = 'none';
- break;
- }
-
- for (let event of value.split('\n').filter((e) => e !== '')) {
- const data = JSON.parse(event);
-
- if (data.event === 'result') {
- noOutputs = false;
- appendOutput(data.url, data.seed, data.config);
- progressEle.setAttribute('value', 0);
- progressEle.setAttribute('max', totalSteps);
- } else if (data.event === 'upscaling-started') {
- document.getElementById('processing_cnt').textContent =
- data.processed_file_cnt;
- document.getElementById('scaling-inprocess-message').style.display =
- 'block';
- } else if (data.event === 'upscaling-done') {
- document.getElementById('scaling-inprocess-message').style.display =
- 'none';
- } else if (data.event === 'step') {
- progressEle.setAttribute('value', data.step);
- if (data.url) {
- progressImageEle.src = data.url;
- }
- } else if (data.event === 'canceled') {
- // avoid alerting as if this were an error case
- noOutputs = false;
- }
- }
- }
-
- // Re-enable form, remove no-results-message
- form.querySelector('fieldset').removeAttribute('disabled');
- document.querySelector('#prompt').value = prompt;
- document.querySelector('progress').setAttribute('value', '0');
-
- if (noOutputs) {
- alert('Error occurred while generating.');
- }
- });
-
- // Disable form while generating
- form.querySelector('fieldset').setAttribute('disabled', '');
- document.querySelector('#prompt').value = `Generating: "${prompt}"`;
-}
-
-async function fetchRunLog() {
- try {
- let response = await fetch('/run_log.json');
- const data = await response.json();
- for (let item of data.run_log) {
- appendOutput(item.url, item.seed, item);
- }
- } catch (e) {
- console.error(e);
- }
-}
-
-window.onload = async () => {
- document.querySelector('#prompt').addEventListener('keydown', (e) => {
- if (e.key === 'Enter' && !e.shiftKey) {
- const form = e.target.form;
- generateSubmit(form);
- }
- });
- document.querySelector('#generate-form').addEventListener('submit', (e) => {
- e.preventDefault();
- const form = e.target;
-
- generateSubmit(form);
- });
- document.querySelector('#generate-form').addEventListener('change', (e) => {
- saveFields(e.target.form);
- });
- document.querySelector('#reset-seed').addEventListener('click', (e) => {
- document.querySelector('#seed').value = -1;
- saveFields(e.target.form);
- });
- document.querySelector('#reset-all').addEventListener('click', (e) => {
- clearFields(e.target.form);
- });
- document.querySelector('#remove-image').addEventListener('click', (e) => {
- initimg.value = null;
- });
- loadFields(document.querySelector('#generate-form'));
-
- document.querySelector('#cancel-button').addEventListener('click', () => {
- fetch('/cancel').catch((e) => {
- console.error(e);
- });
- });
- document.documentElement.addEventListener('keydown', (e) => {
- if (e.key === 'Escape')
- fetch('/cancel').catch((err) => {
- console.error(err);
- });
- });
-
- if (!config.gfpgan_model_exists) {
- document.querySelector('#gfpgan').style.display = 'none';
- }
- await fetchRunLog();
-};
diff --git a/mkdocs.yml b/mkdocs.yml
index f95d83ac8f..97b2a16f19 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -134,6 +134,7 @@ nav:
- List of Default Nodes: 'nodes/defaultNodes.md'
- Workflow Editor Usage: 'nodes/NODES.md'
- ComfyUI to InvokeAI: 'nodes/comfyToInvoke.md'
+ - Facetool Node: 'nodes/detailedNodes/faceTools.md'
- Contributing Nodes: 'nodes/contributingNodes.md'
- Features:
- Overview: 'features/index.md'
@@ -144,7 +145,7 @@ nav:
- Image-to-Image: 'features/IMG2IMG.md'
- Controlling Logging: 'features/LOGGING.md'
- Model Merging: 'features/MODEL_MERGING.md'
- - Using Nodes : './nodes/overview'
+ - Using Nodes : 'nodes/overview.md'
- NSFW Checker: 'features/WATERMARK+NSFW.md'
- Postprocessing: 'features/POSTPROCESS.md'
- Prompting Features: 'features/PROMPTS.md'
@@ -152,15 +153,18 @@ nav:
- Unified Canvas: 'features/UNIFIED_CANVAS.md'
- InvokeAI Web Server: 'features/WEB.md'
- WebUI Hotkeys: "features/WEBUIHOTKEYS.md"
+ - Maintenance Utilities: "features/UTILITIES.md"
- Other: 'features/OTHER.md'
- Contributing:
- How to Contribute: 'contributing/CONTRIBUTING.md'
+ - InvokeAI Code of Conduct: 'CODE_OF_CONDUCT.md'
- Development:
- Overview: 'contributing/contribution_guides/development.md'
- New Contributors: 'contributing/contribution_guides/newContributorChecklist.md'
- InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
- Frontend Documentation: 'contributing/contribution_guides/contributingToFrontend.md'
- Local Development: 'contributing/LOCAL_DEVELOPMENT.md'
+ - Adding Tests: 'contributing/TESTS.md'
- Documentation: 'contributing/contribution_guides/documentation.md'
- Nodes: 'contributing/INVOCATIONS.md'
- Translation: 'contributing/contribution_guides/translation.md'
@@ -168,9 +172,12 @@ nav:
- Changelog: 'CHANGELOG.md'
- Deprecated:
- Command Line Interface: 'deprecated/CLI.md'
+ - Variations: 'deprecated/VARIATIONS.md'
+ - Translations: 'deprecated/TRANSLATION.md'
- Embiggen: 'deprecated/EMBIGGEN.md'
- Inpainting: 'deprecated/INPAINTING.md'
- Outpainting: 'deprecated/OUTPAINTING.md'
+ - Troubleshooting: 'help/deprecated/TROUBLESHOOT.md'
- Help:
- Getting Started: 'help/gettingStartedWithAI.md'
- Diffusion Overview: 'help/diffusion.md'
diff --git a/pyproject.toml b/pyproject.toml
index 67486e1120..d67b096ddc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,10 +122,9 @@ dependencies = [
"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"
-# shortcut commands to start cli and web
+# shortcut commands to start web ui
# "invokeai --web" will launch the web interface
-# "invokeai" will launch the CLI
-"invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
+# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
# new shortcut to launch web interface
"invokeai-web" = "invokeai.app.api_app:invoke_api"
@@ -138,7 +137,6 @@ dependencies = [
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
-"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
@@ -168,11 +166,13 @@ version = { attr = "invokeai.version.__version__" }
]
[tool.setuptools.package-data]
+"invokeai.app.assets" = ["**/*.png"]
"invokeai.assets.fonts" = ["**/*.ttf"]
"invokeai.backend" = ["**.png"]
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
"invokeai.frontend.web.dist" = ["**"]
"invokeai.frontend.web.static" = ["**"]
+"invokeai.app.invocations" = ["**"]
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index 7f634ee1fe..6712196778 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)
- cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
+ # ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
+ ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
+
+ cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py
index 27b8a58bea..171cdfdb6f 100644
--- a/tests/nodes/test_graph_execution_state.py
+++ b/tests/nodes/test_graph_execution_state.py
@@ -75,6 +75,8 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
+ workflow_records=None, # type: ignore
+ workflow_image_records=None, # type: ignore
)
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index 105f7417cd..25b02955b0 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -80,6 +80,8 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
+ workflow_records=None, # type: ignore
+ workflow_image_records=None, # type: ignore
)
diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py
index 3c965895f9..e2a50e61e5 100644
--- a/tests/nodes/test_node_graph.py
+++ b/tests/nodes/test_node_graph.py
@@ -10,7 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
-from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
+from invokeai.app.invocations.primitives import (
+ FloatCollectionInvocation,
+ FloatInvocation,
+ IntegerInvocation,
+ StringInvocation,
+)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import (
@@ -27,8 +32,11 @@ from invokeai.app.services.shared.graph import (
)
from .test_nodes import (
+ AnyTypeTestInvocation,
ImageToImageTestInvocation,
ListPassThroughInvocation,
+ PolymorphicStringTestInvocation,
+ PromptCollectionTestInvocation,
PromptTestInvocation,
TextToImageTestInvocation,
)
@@ -607,8 +615,8 @@ def test_graph_can_deserialize():
g.add_edge(e)
json = g.model_dump_json()
- adapter_graph = TypeAdapter(Graph)
- g2 = adapter_graph.validate_json(json)
+ GraphValidator = TypeAdapter(Graph)
+ g2 = GraphValidator.validate_json(json)
assert g2 is not None
assert g2.nodes["1"] is not None
@@ -692,6 +700,144 @@ def test_ints_do_not_accept_floats():
g.add_edge(e)
+def test_polymorphic_accepts_single():
+ g = Graph()
+ n1 = StringInvocation(id="1", value="banana")
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e1)
+
+
+def test_polymorphic_accepts_collection_of_same_base_type():
+ g = Graph()
+ n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "collection", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e1)
+
+
+def test_polymorphic_does_not_accept_collection_of_different_base_type():
+ g = Graph()
+ n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "collection", n2.id, "value")
+ with pytest.raises(InvalidEdgeError):
+ g.add_edge(e1)
+
+
+def test_polymorphic_does_not_accept_generic_collection():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = PolymorphicStringTestInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "value")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ with pytest.raises(InvalidEdgeError):
+ g.add_edge(e3)
+
+
+def test_any_accepts_integer():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_string():
+ g = Graph()
+ n1 = StringInvocation(id="1", value="banana sundae")
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_generic_collection():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = AnyTypeTestInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "value")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ # Not throwing on this line is sufficient
+ g.add_edge(e3)
+
+
+def test_any_accepts_prompt_collection():
+ g = Graph()
+ n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "collection", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_any():
+ g = Graph()
+ n1 = AnyTypeTestInvocation(id="1")
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_iterate_accepts_collection():
+ """We need to update the validation for Collect -> Iterate to traverse to the Iterate
+ node's output and compare that against the item type of the Collect node's collection. Until
+ then, Collect nodes may not output into Iterate nodes."""
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = IterateInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "collection")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ # Once we fix the validation logic as described, this should not raise an error
+ with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
+ g.add_edge(e3)
+
+
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py
index 471c72a005..51b33dd4c7 100644
--- a/tests/nodes/test_nodes.py
+++ b/tests/nodes/test_nodes.py
@@ -1,11 +1,11 @@
from typing import Any, Callable, Union
-from pydantic import Field
-
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
+ InputField,
InvocationContext,
+ OutputField,
invocation,
invocation_output,
)
@@ -15,12 +15,12 @@ from invokeai.app.invocations.image import ImageField
# Define test invocations before importing anything that uses invocations
@invocation_output("test_list_output")
class ListPassThroughInvocationOutput(BaseInvocationOutput):
- collection: list[ImageField] = Field(default_factory=list)
+ collection: list[ImageField] = OutputField(default_factory=list)
@invocation("test_list")
class ListPassThroughInvocation(BaseInvocation):
- collection: list[ImageField] = Field(default_factory=list)
+ collection: list[ImageField] = InputField(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection)
@@ -28,12 +28,12 @@ class ListPassThroughInvocation(BaseInvocation):
@invocation_output("test_prompt_output")
class PromptTestInvocationOutput(BaseInvocationOutput):
- prompt: str = Field(default="")
+ prompt: str = OutputField(default="")
@invocation("test_prompt")
class PromptTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
+ prompt: str = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt)
@@ -47,13 +47,13 @@ class ErrorInvocation(BaseInvocation):
@invocation_output("test_image_output")
class ImageTestInvocationOutput(BaseInvocationOutput):
- image: ImageField = Field()
+ image: ImageField = OutputField()
@invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
- prompt2: str = Field(default="")
+ prompt: str = InputField(default="")
+ prompt2: str = InputField(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@@ -61,8 +61,8 @@ class TextToImageTestInvocation(BaseInvocation):
@invocation("test_image_to_image")
class ImageToImageTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
- image: Union[ImageField, None] = Field(default=None)
+ prompt: str = InputField(default="")
+ image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@@ -70,17 +70,40 @@ class ImageToImageTestInvocation(BaseInvocation):
@invocation_output("test_prompt_collection_output")
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
- collection: list[str] = Field(default_factory=list)
+ collection: list[str] = OutputField(default_factory=list)
@invocation("test_prompt_collection")
class PromptCollectionTestInvocation(BaseInvocation):
- collection: list[str] = Field()
+ collection: list[str] = InputField()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
+@invocation_output("test_any_output")
+class AnyTypeTestInvocationOutput(BaseInvocationOutput):
+ value: Any = OutputField()
+
+
+@invocation("test_any")
+class AnyTypeTestInvocation(BaseInvocation):
+ value: Any = InputField(default=None)
+
+ def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
+ return AnyTypeTestInvocationOutput(value=self.value)
+
+
+@invocation("test_polymorphic")
+class PolymorphicStringTestInvocation(BaseInvocation):
+ value: Union[str, list[str]] = InputField(default="")
+
+ def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
+ if isinstance(self.value, str):
+ return PromptCollectionTestInvocationOutput(collection=[self.value])
+ return PromptCollectionTestInvocationOutput(collection=self.value)
+
+
# Importing these must happen after test invocations are defined or they won't register
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
diff --git a/tests/nodes/test_session_queue.py b/tests/nodes/test_session_queue.py
index 731316068c..cdab5729f8 100644
--- a/tests/nodes/test_session_queue.py
+++ b/tests/nodes/test_session_queue.py
@@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
assert len(values) == 8
- session_adapter = TypeAdapter(GraphExecutionState)
+ GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
# graph should be serialized
- ges = session_adapter.validate_json(values[0].session)
+ ges = GraphExecutionStateValidator.validate_json(values[0].session)
# graph values should be populated
assert ges.graph.get_node("1").prompt == "Banana sushi"
@@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
assert ges.graph.get_node("4").prompt == "Nissan"
# session ids should match deserialized graph
- assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
+ assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
# should unique session ids
sids = [v.session_id for v in values]
assert len(sids) == len(set(sids))
- nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
+ NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
# should have 3 node field values
assert type(values[0].field_values) is str
- assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
+ assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
# should have batch id and priority
assert all(v.batch_id == b.batch_id for v in values)