-
-
✖
-
-
-
- Postprocessing...1/3
-
-
-
-
-
-
-
-
diff --git a/invokeai/frontend/web/static/legacy_web/index.js b/invokeai/frontend/web/static/legacy_web/index.js
deleted file mode 100644
index a150f3f2e9..0000000000
--- a/invokeai/frontend/web/static/legacy_web/index.js
+++ /dev/null
@@ -1,234 +0,0 @@
-function toBase64(file) {
- return new Promise((resolve, reject) => {
- const r = new FileReader();
- r.readAsDataURL(file);
- r.onload = () => resolve(r.result);
- r.onerror = (error) => reject(error);
- });
-}
-
-function appendOutput(src, seed, config) {
- let outputNode = document.createElement('figure');
-
- let variations = config.with_variations;
- if (config.variation_amount > 0) {
- variations =
- (variations ? variations + ',' : '') +
- seed +
- ':' +
- config.variation_amount;
- }
- let baseseed =
- config.with_variations || config.variation_amount > 0 ? config.seed : seed;
- let altText =
- baseseed + ' | ' + (variations ? variations + ' | ' : '') + config.prompt;
-
- // img needs width and height for lazy loading to work
- const figureContents = `
-
-
-
-
${seed}
- `;
-
- outputNode.innerHTML = figureContents;
- let figcaption = outputNode.querySelector('figcaption');
-
- // Reload image config
- figcaption.addEventListener('click', () => {
- let form = document.querySelector('#generate-form');
- for (const [k, v] of new FormData(form)) {
- if (k == 'initimg') {
- continue;
- }
- form.querySelector(`*[name=${k}]`).value = config[k];
- }
-
- document.querySelector('#seed').value = baseseed;
- document.querySelector('#with_variations').value = variations || '';
- if (document.querySelector('#variation_amount').value <= 0) {
- document.querySelector('#variation_amount').value = 0.2;
- }
-
- saveFields(document.querySelector('#generate-form'));
- });
-
- document.querySelector('#results').prepend(outputNode);
-}
-
-function saveFields(form) {
- for (const [k, v] of new FormData(form)) {
- if (typeof v !== 'object') {
- // Don't save 'file' type
- localStorage.setItem(k, v);
- }
- }
-}
-
-function loadFields(form) {
- for (const [k, v] of new FormData(form)) {
- const item = localStorage.getItem(k);
- if (item != null) {
- form.querySelector(`*[name=${k}]`).value = item;
- }
- }
-}
-
-function clearFields(form) {
- localStorage.clear();
- let prompt = form.prompt.value;
- form.reset();
- form.prompt.value = prompt;
-}
-
-const BLANK_IMAGE_URL =
- 'data:image/svg+xml,
';
-async function generateSubmit(form) {
- const prompt = document.querySelector('#prompt').value;
-
- // Convert file data to base64
- let formData = Object.fromEntries(new FormData(form));
- formData.initimg_name = formData.initimg.name;
- formData.initimg =
- formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
-
- let strength = formData.strength;
- let totalSteps = formData.initimg
- ? Math.floor(strength * formData.steps)
- : formData.steps;
-
- let progressSectionEle = document.querySelector('#progress-section');
- progressSectionEle.style.display = 'initial';
- let progressEle = document.querySelector('#progress-bar');
- progressEle.setAttribute('max', totalSteps);
- let progressImageEle = document.querySelector('#progress-image');
- progressImageEle.src = BLANK_IMAGE_URL;
-
- progressImageEle.style.display = {}.hasOwnProperty.call(
- formData,
- 'progress_images'
- )
- ? 'initial'
- : 'none';
-
- // Post as JSON, using Fetch streaming to get results
- fetch(form.action, {
- method: form.method,
- body: JSON.stringify(formData),
- }).then(async (response) => {
- const reader = response.body.getReader();
-
- let noOutputs = true;
- while (true) {
- let { value, done } = await reader.read();
- value = new TextDecoder().decode(value);
- if (done) {
- progressSectionEle.style.display = 'none';
- break;
- }
-
- for (let event of value.split('\n').filter((e) => e !== '')) {
- const data = JSON.parse(event);
-
- if (data.event === 'result') {
- noOutputs = false;
- appendOutput(data.url, data.seed, data.config);
- progressEle.setAttribute('value', 0);
- progressEle.setAttribute('max', totalSteps);
- } else if (data.event === 'upscaling-started') {
- document.getElementById('processing_cnt').textContent =
- data.processed_file_cnt;
- document.getElementById('scaling-inprocess-message').style.display =
- 'block';
- } else if (data.event === 'upscaling-done') {
- document.getElementById('scaling-inprocess-message').style.display =
- 'none';
- } else if (data.event === 'step') {
- progressEle.setAttribute('value', data.step);
- if (data.url) {
- progressImageEle.src = data.url;
- }
- } else if (data.event === 'canceled') {
- // avoid alerting as if this were an error case
- noOutputs = false;
- }
- }
- }
-
- // Re-enable form, remove no-results-message
- form.querySelector('fieldset').removeAttribute('disabled');
- document.querySelector('#prompt').value = prompt;
- document.querySelector('progress').setAttribute('value', '0');
-
- if (noOutputs) {
- alert('Error occurred while generating.');
- }
- });
-
- // Disable form while generating
- form.querySelector('fieldset').setAttribute('disabled', '');
- document.querySelector('#prompt').value = `Generating: "${prompt}"`;
-}
-
-async function fetchRunLog() {
- try {
- let response = await fetch('/run_log.json');
- const data = await response.json();
- for (let item of data.run_log) {
- appendOutput(item.url, item.seed, item);
- }
- } catch (e) {
- console.error(e);
- }
-}
-
-window.onload = async () => {
- document.querySelector('#prompt').addEventListener('keydown', (e) => {
- if (e.key === 'Enter' && !e.shiftKey) {
- const form = e.target.form;
- generateSubmit(form);
- }
- });
- document.querySelector('#generate-form').addEventListener('submit', (e) => {
- e.preventDefault();
- const form = e.target;
-
- generateSubmit(form);
- });
- document.querySelector('#generate-form').addEventListener('change', (e) => {
- saveFields(e.target.form);
- });
- document.querySelector('#reset-seed').addEventListener('click', (e) => {
- document.querySelector('#seed').value = -1;
- saveFields(e.target.form);
- });
- document.querySelector('#reset-all').addEventListener('click', (e) => {
- clearFields(e.target.form);
- });
- document.querySelector('#remove-image').addEventListener('click', (e) => {
- initimg.value = null;
- });
- loadFields(document.querySelector('#generate-form'));
-
- document.querySelector('#cancel-button').addEventListener('click', () => {
- fetch('/cancel').catch((e) => {
- console.error(e);
- });
- });
- document.documentElement.addEventListener('keydown', (e) => {
- if (e.key === 'Escape')
- fetch('/cancel').catch((err) => {
- console.error(err);
- });
- });
-
- if (!config.gfpgan_model_exists) {
- document.querySelector('#gfpgan').style.display = 'none';
- }
- await fetchRunLog();
-};
diff --git a/mkdocs.yml b/mkdocs.yml
index f95d83ac8f..97b2a16f19 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -134,6 +134,7 @@ nav:
- List of Default Nodes: 'nodes/defaultNodes.md'
- Workflow Editor Usage: 'nodes/NODES.md'
- ComfyUI to InvokeAI: 'nodes/comfyToInvoke.md'
+ - Facetool Node: 'nodes/detailedNodes/faceTools.md'
- Contributing Nodes: 'nodes/contributingNodes.md'
- Features:
- Overview: 'features/index.md'
@@ -144,7 +145,7 @@ nav:
- Image-to-Image: 'features/IMG2IMG.md'
- Controlling Logging: 'features/LOGGING.md'
- Model Merging: 'features/MODEL_MERGING.md'
- - Using Nodes : './nodes/overview'
+ - Using Nodes : 'nodes/overview.md'
- NSFW Checker: 'features/WATERMARK+NSFW.md'
- Postprocessing: 'features/POSTPROCESS.md'
- Prompting Features: 'features/PROMPTS.md'
@@ -152,15 +153,18 @@ nav:
- Unified Canvas: 'features/UNIFIED_CANVAS.md'
- InvokeAI Web Server: 'features/WEB.md'
- WebUI Hotkeys: "features/WEBUIHOTKEYS.md"
+ - Maintenance Utilities: "features/UTILITIES.md"
- Other: 'features/OTHER.md'
- Contributing:
- How to Contribute: 'contributing/CONTRIBUTING.md'
+ - InvokeAI Code of Conduct: 'CODE_OF_CONDUCT.md'
- Development:
- Overview: 'contributing/contribution_guides/development.md'
- New Contributors: 'contributing/contribution_guides/newContributorChecklist.md'
- InvokeAI Architecture: 'contributing/ARCHITECTURE.md'
- Frontend Documentation: 'contributing/contribution_guides/contributingToFrontend.md'
- Local Development: 'contributing/LOCAL_DEVELOPMENT.md'
+ - Adding Tests: 'contributing/TESTS.md'
- Documentation: 'contributing/contribution_guides/documentation.md'
- Nodes: 'contributing/INVOCATIONS.md'
- Translation: 'contributing/contribution_guides/translation.md'
@@ -168,9 +172,12 @@ nav:
- Changelog: 'CHANGELOG.md'
- Deprecated:
- Command Line Interface: 'deprecated/CLI.md'
+ - Variations: 'deprecated/VARIATIONS.md'
+ - Translations: 'deprecated/TRANSLATION.md'
- Embiggen: 'deprecated/EMBIGGEN.md'
- Inpainting: 'deprecated/INPAINTING.md'
- Outpainting: 'deprecated/OUTPAINTING.md'
+ - Troubleshooting: 'help/deprecated/TROUBLESHOOT.md'
- Help:
- Getting Started: 'help/gettingStartedWithAI.md'
- Diffusion Overview: 'help/diffusion.md'
diff --git a/pyproject.toml b/pyproject.toml
index 67486e1120..89f4ea2b45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
"datasets",
# When bumping diffusers beyond 0.21, make sure to address this:
# https://github.com/invoke-ai/InvokeAI/blob/fc09ab7e13cb7ca5389100d149b6422ace7b8ed3/invokeai/app/invocations/latent.py#L513
- "diffusers[torch]~=0.21.0",
+ "diffusers[torch]~=0.22.0",
"dnspython~=2.4.0",
"dynamicprompts",
"easing-functions",
@@ -122,10 +122,9 @@ dependencies = [
"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"
-# shortcut commands to start cli and web
+# shortcut commands to start web ui
# "invokeai --web" will launch the web interface
-# "invokeai" will launch the CLI
-"invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
+# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"
# new shortcut to launch web interface
"invokeai-web" = "invokeai.app.api_app:invoke_api"
@@ -138,7 +137,6 @@ dependencies = [
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
-"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
@@ -168,11 +166,13 @@ version = { attr = "invokeai.version.__version__" }
]
[tool.setuptools.package-data]
+"invokeai.app.assets" = ["**/*.png"]
"invokeai.assets.fonts" = ["**/*.ttf"]
"invokeai.backend" = ["**.png"]
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
"invokeai.frontend.web.dist" = ["**"]
"invokeai.frontend.web.static" = ["**"]
+"invokeai.app.invocations" = ["**"]
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
@@ -206,6 +206,7 @@ exclude = [
"build",
"dist",
"invokeai/frontend/web/node_modules/",
+ ".venv*",
]
[tool.black]
diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py
index 7f634ee1fe..6712196778 100644
--- a/tests/backend/ip_adapter/test_ip_adapter.py
+++ b/tests/backend/ip_adapter/test_ip_adapter.py
@@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)
- cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
+ # ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
+ ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
+
+ cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
diff --git a/tests/backend/model_management/test_lora.py b/tests/backend/model_management/test_lora.py
new file mode 100644
index 0000000000..14bcc87c89
--- /dev/null
+++ b/tests/backend/model_management/test_lora.py
@@ -0,0 +1,102 @@
+# test that if the model's device changes while the lora is applied, the weights can still be restored
+
+# test that LoRA patching works on both CPU and CUDA
+
+import pytest
+import torch
+
+from invokeai.backend.model_management.lora import ModelPatcher
+from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
+
+
+@pytest.mark.parametrize(
+ "device",
+ [
+ "cpu",
+ pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+ ],
+)
+@torch.no_grad()
+def test_apply_lora(device):
+ """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
+ result, and that model/LoRA tensors are moved between devices as expected.
+ """
+
+ linear_in_features = 4
+ linear_out_features = 8
+ lora_dim = 2
+ model = torch.nn.ModuleDict(
+ {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
+ )
+
+ lora_layers = {
+ "linear_layer_1": LoRALayer(
+ layer_key="linear_layer_1",
+ values={
+ "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+ "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+ },
+ )
+ }
+ lora = LoRAModelRaw("lora_name", lora_layers)
+
+ lora_weight = 0.5
+ orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+ expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
+
+ with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
+ # After patching, all LoRA layer weights should have been moved back to the cpu.
+ assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+ assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+ # After patching, the patched model should still be on its original device.
+ assert model["linear_layer_1"].weight.data.device.type == device
+
+ torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)
+
+ # After unpatching, the original model weights should have been restored on the original device.
+ assert model["linear_layer_1"].weight.data.device.type == device
+ torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
+@torch.no_grad()
+def test_apply_lora_change_device():
+ """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
+ still behaves correctly.
+ """
+ linear_in_features = 4
+ linear_out_features = 8
+ lora_dim = 2
+ # Initialize the model on the CPU.
+ model = torch.nn.ModuleDict(
+ {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
+ )
+
+ lora_layers = {
+ "linear_layer_1": LoRALayer(
+ layer_key="linear_layer_1",
+ values={
+ "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+ "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+ },
+ )
+ }
+ lora = LoRAModelRaw("lora_name", lora_layers)
+
+ orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+
+ with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
+ # After patching, all LoRA layer weights should have been moved back to the cpu.
+ assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+ assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+ # After patching, the patched model should still be on the CPU.
+ assert model["linear_layer_1"].weight.data.device.type == "cpu"
+
+ # Move the model to the GPU.
+ assert model.to("cuda")
+
+ # After unpatching, the original model weights should have been restored on the GPU.
+ assert model["linear_layer_1"].weight.data.device.type == "cuda"
+ torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)
diff --git a/tests/backend/model_management/test_memory_snapshot.py b/tests/backend/model_management/test_memory_snapshot.py
index 80aed7b7ba..216cd62171 100644
--- a/tests/backend/model_management/test_memory_snapshot.py
+++ b/tests/backend/model_management/test_memory_snapshot.py
@@ -13,10 +13,11 @@ def test_memory_snapshot_capture():
snapshots = [
- MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
- MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
- MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
- MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
+ MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
+ MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
+ MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
+ MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
+ None,
]
@@ -26,10 +27,12 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
- expected_lines = 1
- if snapshot_1.vram is not None and snapshot_2.vram is not None:
+ expected_lines = 0
+ if snapshot_1 is not None and snapshot_2 is not None:
expected_lines += 1
- if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
- expected_lines += 5
+ if snapshot_1.vram is not None and snapshot_2.vram is not None:
+ expected_lines += 1
+ if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
+ expected_lines += 5
assert len(msg.splitlines()) == expected_lines
diff --git a/tests/backend/model_management/test_model_load_optimization.py b/tests/backend/model_management/test_model_load_optimization.py
index 43f007e972..a4fe1dd597 100644
--- a/tests/backend/model_management/test_model_load_optimization.py
+++ b/tests/backend/model_management/test_model_load_optimization.py
@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, s
(torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
(torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
(torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+ (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
],
)
def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
# Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
assert reset_params_fn_during == _no_op
assert not torch.allclose(layer_before.weight, layer_during.weight)
- assert not torch.allclose(layer_before.bias, layer_during.bias)
+ if hasattr(layer_before, "bias"):
+ assert not torch.allclose(layer_before.bias, layer_during.bias)
# Check that the original behavior is restored after `skip_torch_weight_init()` ends.
assert reset_params_fn_before is reset_params_fn_after
assert torch.allclose(layer_before.weight, layer_after.weight)
- assert torch.allclose(layer_before.bias, layer_after.bias)
+ if hasattr(layer_before, "bias"):
+ assert torch.allclose(layer_before.bias, layer_after.bias)
def test_skip_torch_weight_init_restores_base_class_behavior():
diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py
index 27b8a58bea..171cdfdb6f 100644
--- a/tests/nodes/test_graph_execution_state.py
+++ b/tests/nodes/test_graph_execution_state.py
@@ -75,6 +75,8 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
+ workflow_records=None, # type: ignore
+ workflow_image_records=None, # type: ignore
)
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index 105f7417cd..25b02955b0 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -80,6 +80,8 @@ def mock_services() -> InvocationServices:
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
+ workflow_records=None, # type: ignore
+ workflow_image_records=None, # type: ignore
)
diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py
index 3c965895f9..e2a50e61e5 100644
--- a/tests/nodes/test_node_graph.py
+++ b/tests/nodes/test_node_graph.py
@@ -10,7 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
-from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
+from invokeai.app.invocations.primitives import (
+ FloatCollectionInvocation,
+ FloatInvocation,
+ IntegerInvocation,
+ StringInvocation,
+)
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.default_graphs import create_text_to_image
from invokeai.app.services.shared.graph import (
@@ -27,8 +32,11 @@ from invokeai.app.services.shared.graph import (
)
from .test_nodes import (
+ AnyTypeTestInvocation,
ImageToImageTestInvocation,
ListPassThroughInvocation,
+ PolymorphicStringTestInvocation,
+ PromptCollectionTestInvocation,
PromptTestInvocation,
TextToImageTestInvocation,
)
@@ -607,8 +615,8 @@ def test_graph_can_deserialize():
g.add_edge(e)
json = g.model_dump_json()
- adapter_graph = TypeAdapter(Graph)
- g2 = adapter_graph.validate_json(json)
+ GraphValidator = TypeAdapter(Graph)
+ g2 = GraphValidator.validate_json(json)
assert g2 is not None
assert g2.nodes["1"] is not None
@@ -692,6 +700,144 @@ def test_ints_do_not_accept_floats():
g.add_edge(e)
+def test_polymorphic_accepts_single():
+ g = Graph()
+ n1 = StringInvocation(id="1", value="banana")
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e1)
+
+
+def test_polymorphic_accepts_collection_of_same_base_type():
+ g = Graph()
+ n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "collection", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e1)
+
+
+def test_polymorphic_does_not_accept_collection_of_different_base_type():
+ g = Graph()
+ n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
+ n2 = PolymorphicStringTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e1 = create_edge(n1.id, "collection", n2.id, "value")
+ with pytest.raises(InvalidEdgeError):
+ g.add_edge(e1)
+
+
+def test_polymorphic_does_not_accept_generic_collection():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = PolymorphicStringTestInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "value")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ with pytest.raises(InvalidEdgeError):
+ g.add_edge(e3)
+
+
+def test_any_accepts_integer():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_string():
+ g = Graph()
+ n1 = StringInvocation(id="1", value="banana sundae")
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_generic_collection():
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = AnyTypeTestInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "value")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ # Not throwing on this line is sufficient
+ g.add_edge(e3)
+
+
+def test_any_accepts_prompt_collection():
+ g = Graph()
+ n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "collection", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_any_accepts_any():
+ g = Graph()
+ n1 = AnyTypeTestInvocation(id="1")
+ n2 = AnyTypeTestInvocation(id="2")
+ g.add_node(n1)
+ g.add_node(n2)
+ e = create_edge(n1.id, "value", n2.id, "value")
+ # Not throwing on this line is sufficient
+ g.add_edge(e)
+
+
+def test_iterate_accepts_collection():
+ """We need to update the validation for Collect -> Iterate to traverse to the Iterate
+ node's output and compare that against the item type of the Collect node's collection. Until
+ then, Collect nodes may not output into Iterate nodes."""
+ g = Graph()
+ n1 = IntegerInvocation(id="1", value=1)
+ n2 = IntegerInvocation(id="2", value=2)
+ n3 = CollectInvocation(id="3")
+ n4 = IterateInvocation(id="4")
+ g.add_node(n1)
+ g.add_node(n2)
+ g.add_node(n3)
+ g.add_node(n4)
+ e1 = create_edge(n1.id, "value", n3.id, "item")
+ e2 = create_edge(n2.id, "value", n3.id, "item")
+ e3 = create_edge(n3.id, "collection", n4.id, "collection")
+ g.add_edge(e1)
+ g.add_edge(e2)
+ # Once we fix the validation logic as described, this should should not raise an error
+ with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
+ g.add_edge(e3)
+
+
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py
index 471c72a005..51b33dd4c7 100644
--- a/tests/nodes/test_nodes.py
+++ b/tests/nodes/test_nodes.py
@@ -1,11 +1,11 @@
from typing import Any, Callable, Union
-from pydantic import Field
-
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
+ InputField,
InvocationContext,
+ OutputField,
invocation,
invocation_output,
)
@@ -15,12 +15,12 @@ from invokeai.app.invocations.image import ImageField
# Define test invocations before importing anything that uses invocations
@invocation_output("test_list_output")
class ListPassThroughInvocationOutput(BaseInvocationOutput):
- collection: list[ImageField] = Field(default_factory=list)
+ collection: list[ImageField] = OutputField(default_factory=list)
@invocation("test_list")
class ListPassThroughInvocation(BaseInvocation):
- collection: list[ImageField] = Field(default_factory=list)
+ collection: list[ImageField] = InputField(default_factory=list)
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection)
@@ -28,12 +28,12 @@ class ListPassThroughInvocation(BaseInvocation):
@invocation_output("test_prompt_output")
class PromptTestInvocationOutput(BaseInvocationOutput):
- prompt: str = Field(default="")
+ prompt: str = OutputField(default="")
@invocation("test_prompt")
class PromptTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
+ prompt: str = InputField(default="")
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt)
@@ -47,13 +47,13 @@ class ErrorInvocation(BaseInvocation):
@invocation_output("test_image_output")
class ImageTestInvocationOutput(BaseInvocationOutput):
- image: ImageField = Field()
+ image: ImageField = OutputField()
@invocation("test_text_to_image")
class TextToImageTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
- prompt2: str = Field(default="")
+ prompt: str = InputField(default="")
+ prompt2: str = InputField(default="")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@@ -61,8 +61,8 @@ class TextToImageTestInvocation(BaseInvocation):
@invocation("test_image_to_image")
class ImageToImageTestInvocation(BaseInvocation):
- prompt: str = Field(default="")
- image: Union[ImageField, None] = Field(default=None)
+ prompt: str = InputField(default="")
+ image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@@ -70,17 +70,40 @@ class ImageToImageTestInvocation(BaseInvocation):
@invocation_output("test_prompt_collection_output")
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
- collection: list[str] = Field(default_factory=list)
+ collection: list[str] = OutputField(default_factory=list)
@invocation("test_prompt_collection")
class PromptCollectionTestInvocation(BaseInvocation):
- collection: list[str] = Field()
+ collection: list[str] = InputField()
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
+@invocation_output("test_any_output")
+class AnyTypeTestInvocationOutput(BaseInvocationOutput):
+ value: Any = OutputField()
+
+
+@invocation("test_any")
+class AnyTypeTestInvocation(BaseInvocation):
+ value: Any = InputField(default=None)
+
+ def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
+ return AnyTypeTestInvocationOutput(value=self.value)
+
+
+@invocation("test_polymorphic")
+class PolymorphicStringTestInvocation(BaseInvocation):
+ value: Union[str, list[str]] = InputField(default="")
+
+ def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
+ if isinstance(self.value, str):
+ return PromptCollectionTestInvocationOutput(collection=[self.value])
+ return PromptCollectionTestInvocationOutput(collection=self.value)
+
+
# Importing these must happen after test invocations are defined or they won't register
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
diff --git a/tests/nodes/test_session_queue.py b/tests/nodes/test_session_queue.py
index 731316068c..cdab5729f8 100644
--- a/tests/nodes/test_session_queue.py
+++ b/tests/nodes/test_session_queue.py
@@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
assert len(values) == 8
- session_adapter = TypeAdapter(GraphExecutionState)
+ GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
# graph should be serialized
- ges = session_adapter.validate_json(values[0].session)
+ ges = GraphExecutionStateValidator.validate_json(values[0].session)
# graph values should be populated
assert ges.graph.get_node("1").prompt == "Banana sushi"
@@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
assert ges.graph.get_node("4").prompt == "Nissan"
# session ids should match deserialized graph
- assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
+ assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
# should unique session ids
sids = [v.session_id for v in values]
assert len(sids) == len(set(sids))
- nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
+ NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
# should have 3 node field values
assert type(values[0].field_values) is str
- assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
+ assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
# should have batch id and priority
assert all(v.batch_id == b.batch_id for v in values)