From c00aea7a6ce7e257c2c7324777dc0b15c01cedfd Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 29 Jun 2023 16:01:17 +1000
Subject: [PATCH] tests(nodes): fix nodes tests

---
 invokeai/app/services/invocation_services.py |  67 +++--
 tests/conftest.py                            |  30 --
 tests/nodes/test_graph_execution_state.py    | 130 +++++---
 tests/nodes/test_invoker.py                  |  88 ++++--
 tests/nodes/test_node_graph.py               |  75 +++--
 tests/nodes/test_nodes.py                    |  16 +-
 tests/test_textual_inversion.py              | 301 -------------------
 7 files changed, 242 insertions(+), 465 deletions(-)
 delete mode 100644 tests/conftest.py
 delete mode 100644 tests/test_textual_inversion.py

diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py
index 10d1d91920..4e1da3b040 100644
--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -7,7 +7,7 @@ if TYPE_CHECKING:
     from invokeai.app.services.board_images import BoardImagesServiceABC
     from invokeai.app.services.boards import BoardServiceABC
     from invokeai.app.services.images import ImageServiceABC
-    from invokeai.backend import ModelManager
+    from invokeai.app.services.model_manager_service import ModelManagerServiceBase
     from invokeai.app.services.events import EventServiceBase
     from invokeai.app.services.latent_storage import LatentsStorageBase
     from invokeai.app.services.restoration_services import RestorationServices
@@ -22,46 +22,47 @@ class InvocationServices:
     """Services that can be used by invocations"""
 
     # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
-    events: "EventServiceBase"
-    latents: "LatentsStorageBase"
-    queue: "InvocationQueueABC"
-    model_manager: "ModelManager"
-    restoration: "RestorationServices"
-    configuration: "InvokeAISettings"
-    images: "ImageServiceABC"
-    boards: "BoardServiceABC"
     board_images: "BoardImagesServiceABC"
-    graph_library: "ItemStorageABC"["LibraryGraph"]
+    boards: "BoardServiceABC"
+    configuration: "InvokeAISettings"
+    events: "EventServiceBase"
     graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
+    graph_library: "ItemStorageABC"["LibraryGraph"]
+    images: "ImageServiceABC"
+    latents: "LatentsStorageBase"
+    logger: "Logger"
+    model_manager: "ModelManagerServiceBase"
     processor: "InvocationProcessorABC"
+    queue: "InvocationQueueABC"
+    restoration: "RestorationServices"
 
     def __init__(
         self,
-        model_manager: "ModelManager",
-        events: "EventServiceBase",
-        logger: "Logger",
-        latents: "LatentsStorageBase",
-        images: "ImageServiceABC",
-        boards: "BoardServiceABC",
         board_images: "BoardImagesServiceABC",
-        queue: "InvocationQueueABC",
-        graph_library: "ItemStorageABC"["LibraryGraph"],
-        graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
-        processor: "InvocationProcessorABC",
-        restoration: "RestorationServices",
+        boards: "BoardServiceABC",
         configuration: "InvokeAISettings",
+        events: "EventServiceBase",
+        graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
+        graph_library: "ItemStorageABC"["LibraryGraph"],
+        images: "ImageServiceABC",
+        latents: "LatentsStorageBase",
+        logger: "Logger",
+        model_manager: "ModelManagerServiceBase",
+        processor: "InvocationProcessorABC",
+        queue: "InvocationQueueABC",
+        restoration: "RestorationServices",
     ):
-        self.model_manager = model_manager
-        self.events = events
-        self.logger = logger
-        self.latents = latents
-        self.images = images
-        self.boards = boards
         self.board_images = board_images
-        self.queue = queue
-        self.graph_library = graph_library
-        self.graph_execution_manager = graph_execution_manager
-        self.processor = processor
-        self.restoration = restoration
-        self.configuration = configuration
         self.boards = boards
+        self.boards = boards
+        self.configuration = configuration
+        self.events = events
+        self.graph_execution_manager = graph_execution_manager
+        self.graph_library = graph_library
+        self.images = images
+        self.latents = latents
+        self.logger = logger
+        self.model_manager = model_manager
+        self.processor = processor
+        self.queue = queue
+        self.restoration = restoration
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index 06502b6c41..0000000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pytest
-from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_queue import MemoryInvocationQueue
-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
-from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
-from invokeai.app.services.processor import DefaultInvocationProcessor
-
-# Ignore these files as they need to be rewritten following the model manager refactor
-collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
-
-@pytest.fixture(scope="session", autouse=True)
-def mock_services():
-    # NOTE: none of these are actually called by the test invocations
-    return InvocationServices(
-        model_manager = None, # type: ignore
-        events = None, # type: ignore
-        logger = None, # type: ignore
-        images = None, # type: ignore
-        latents = None, # type: ignore
-        board_images=None, # type: ignore
-        boards=None, # type: ignore
-        queue = MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](
-            filename=sqlite_memory, table_name="graphs"
-        ),
-        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
-        processor = DefaultInvocationProcessor(),
-        restoration = None, # type: ignore
-        configuration = None, # type: ignore
-    )
diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py
index df8964da18..f34b18310b 100644
--- a/tests/nodes/test_graph_execution_state.py
+++ b/tests/nodes/test_graph_execution_state.py
@@ -1,41 +1,81 @@
-import pytest
-
-from invokeai.app.invocations.baseinvocation import (BaseInvocation,
-                                                     BaseInvocationOutput,
-                                                     InvocationContext)
+from .test_invoker import create_edge
+from .test_nodes import (
+    TestEventService,
+    TextToImageTestInvocation,
+    PromptTestInvocation,
+    PromptCollectionTestInvocation,
+)
+from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.processor import DefaultInvocationProcessor
+from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationContext,
+)
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
-from invokeai.app.services.graph import (CollectInvocation, Graph,
-                                         GraphExecutionState,
-                                         IterateInvocation)
 from invokeai.app.services.invocation_services import InvocationServices
-
-from .test_invoker import create_edge
-from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
-                         PromptTestInvocation)
+from invokeai.app.services.graph import (
+    Graph,
+    CollectInvocation,
+    IterateInvocation,
+    GraphExecutionState,
+    LibraryGraph,
+)
+import pytest
 
 
 @pytest.fixture
 def simple_graph():
     g = Graph()
-    g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
-    g.add_node(ImageTestInvocation(id = "2"))
+    g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
+    g.add_node(TextToImageTestInvocation(id="2"))
     g.add_edge(create_edge("1", "prompt", "2", "prompt"))
     return g
 
-def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
+
+# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
+# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
+# the test invocations.
+@pytest.fixture
+def mock_services() -> InvocationServices:
+    # NOTE: none of these are actually called by the test invocations
+    return InvocationServices(
+        model_manager = None, # type: ignore
+        events = TestEventService(),
+        logger = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
+        boards = None, # type: ignore
+        board_images = None, # type: ignore
+        queue = MemoryInvocationQueue(),
+        graph_library=SqliteItemStorage[LibraryGraph](
+            filename=sqlite_memory, table_name="graphs"
+        ),
+        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
+        processor = DefaultInvocationProcessor(),
+        restoration = None, # type: ignore
+        configuration = None, # type: ignore
+    )
+
+
+def invoke_next(
+    g: GraphExecutionState, services: InvocationServices
+) -> tuple[BaseInvocation, BaseInvocationOutput]:
     n = g.next()
     if n is None:
         return (None, None)
 
-    print(f'invoking {n.id}: {type(n)}')
+    print(f"invoking {n.id}: {type(n)}")
     o = n.invoke(InvocationContext(services, "1"))
     g.complete(n.id, o)
 
     return (n, o)
 
+
 def test_graph_state_executes_in_order(simple_graph, mock_services):
-    g = GraphExecutionState(graph = simple_graph)
+    g = GraphExecutionState(graph=simple_graph)
 
     n1 = invoke_next(g, mock_services)
     n2 = invoke_next(g, mock_services)
@@ -47,38 +87,42 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
     assert g.results[n1[0].id].prompt == n1[0].prompt
     assert n2[0].prompt == n1[0].prompt
 
+
 def test_graph_is_complete(simple_graph, mock_services):
-    g = GraphExecutionState(graph = simple_graph)
+    g = GraphExecutionState(graph=simple_graph)
     n1 = invoke_next(g, mock_services)
     n2 = invoke_next(g, mock_services)
     n3 = g.next()
 
     assert g.is_complete()
 
+
 def test_graph_is_not_complete(simple_graph, mock_services):
-    g = GraphExecutionState(graph = simple_graph)
+    g = GraphExecutionState(graph=simple_graph)
     n1 = invoke_next(g, mock_services)
     n2 = g.next()
 
     assert not g.is_complete()
 
+
 # TODO: test completion with iterators/subgraphs
 
+
 def test_graph_state_expands_iterator(mock_services):
     graph = Graph()
-    graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
-    graph.add_node(IterateInvocation(id = "1"))
-    graph.add_node(MultiplyInvocation(id = "2", b = 10))
-    graph.add_node(AddInvocation(id = "3", b = 1))
+    graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
+    graph.add_node(IterateInvocation(id="1"))
+    graph.add_node(MultiplyInvocation(id="2", b=10))
+    graph.add_node(AddInvocation(id="3", b=1))
     graph.add_edge(create_edge("0", "collection", "1", "collection"))
     graph.add_edge(create_edge("1", "item", "2", "a"))
     graph.add_edge(create_edge("2", "a", "3", "a"))
 
-    g = GraphExecutionState(graph = graph)
+    g = GraphExecutionState(graph=graph)
     while not g.is_complete():
         invoke_next(g, mock_services)
 
-    prepared_add_nodes = g.source_prepared_mapping['3']
+    prepared_add_nodes = g.source_prepared_mapping["3"]
     results = set([g.results[n].a for n in prepared_add_nodes])
     expected = set([1, 11, 21])
     assert results == expected
@@ -87,15 +131,17 @@ def test_graph_state_expands_iterator(mock_services):
 def test_graph_state_collects(mock_services):
     graph = Graph()
     test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
-    graph.add_node(IterateInvocation(id = "2"))
-    graph.add_node(PromptTestInvocation(id = "3"))
-    graph.add_node(CollectInvocation(id = "4"))
+    graph.add_node(
+        PromptCollectionTestInvocation(id="1", collection=list(test_prompts))
+    )
+    graph.add_node(IterateInvocation(id="2"))
+    graph.add_node(PromptTestInvocation(id="3"))
+    graph.add_node(CollectInvocation(id="4"))
     graph.add_edge(create_edge("1", "collection", "2", "collection"))
     graph.add_edge(create_edge("2", "item", "3", "prompt"))
     graph.add_edge(create_edge("3", "prompt", "4", "item"))
 
-    g = GraphExecutionState(graph = graph)
+    g = GraphExecutionState(graph=graph)
     n1 = invoke_next(g, mock_services)
     n2 = invoke_next(g, mock_services)
     n3 = invoke_next(g, mock_services)
@@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services):
     graph = Graph()
 
     test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
+    graph.add_node(
+        PromptCollectionTestInvocation(
+            id="prompt_collection", collection=list(test_prompts)
+        )
+    )
     graph.add_node(IterateInvocation(id="iterate"))
     graph.add_node(PromptTestInvocation(id="prompt_iterated"))
-    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
+    graph.add_edge(
+        create_edge("prompt_collection", "collection", "iterate", "collection")
+    )
     graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
 
     # separated, fully-preparable chain of nodes
@@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services):
     graph = Graph()
 
     test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
+    graph.add_node(
+        PromptCollectionTestInvocation(
+            id="prompt_collection", collection=list(test_prompts)
+        )
+    )
     graph.add_node(IterateInvocation(id="iterate"))
     graph.add_node(PromptTestInvocation(id="prompt_iterated"))
     graph.add_node(PromptTestInvocation(id="prompt_successor"))
-    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
+    graph.add_edge(
+        create_edge("prompt_collection", "collection", "iterate", "collection")
+    )
     graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
-    graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
+    graph.add_edge(
+        create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")
+    )
 
     g = GraphExecutionState(graph=graph)
     n1 = invoke_next(g, mock_services)
diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py
index 4331e62d21..19d7dd20b3 100644
--- a/tests/nodes/test_invoker.py
+++ b/tests/nodes/test_invoker.py
@@ -1,26 +1,62 @@
-import pytest
-
-from invokeai.app.services.graph import Graph, GraphExecutionState
-from invokeai.app.services.invocation_services import InvocationServices
+from .test_nodes import (
+    TestEventService,
+    ErrorInvocation,
+    TextToImageTestInvocation,
+    PromptTestInvocation,
+    create_edge,
+    wait_until,
+)
+from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.processor import DefaultInvocationProcessor
+from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from invokeai.app.services.invoker import Invoker
-
-from .test_nodes import (ErrorInvocation, ImageTestInvocation,
-                         PromptTestInvocation, create_edge, wait_until)
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.graph import (
+    Graph,
+    GraphExecutionState,
+    LibraryGraph,
+)
+import pytest
 
 
 @pytest.fixture
 def simple_graph():
     g = Graph()
-    g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi"))
-    g.add_node(ImageTestInvocation(id = "2"))
+    g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
+    g.add_node(TextToImageTestInvocation(id="2"))
     g.add_edge(create_edge("1", "prompt", "2", "prompt"))
     return g
 
+
+# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
+# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
+# the test invocations.
+@pytest.fixture
+def mock_services() -> InvocationServices:
+    # NOTE: none of these are actually called by the test invocations
+    return InvocationServices(
+        model_manager = None, # type: ignore
+        events = TestEventService(),
+        logger = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
+        boards = None, # type: ignore
+        board_images = None, # type: ignore
+        queue = MemoryInvocationQueue(),
+        graph_library=SqliteItemStorage[LibraryGraph](
+            filename=sqlite_memory, table_name="graphs"
+        ),
+        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
+        processor = DefaultInvocationProcessor(),
+        restoration = None, # type: ignore
+        configuration = None, # type: ignore
+    )
+
+
 @pytest.fixture()
 def mock_invoker(mock_services: InvocationServices) -> Invoker:
-    return Invoker(
-        services = mock_services
-    )
+    return Invoker(services=mock_services)
+
 
 def test_can_create_graph_state(mock_invoker: Invoker):
     g = mock_invoker.create_execution_state()
@@ -29,17 +65,19 @@ def test_can_create_graph_state(mock_invoker: Invoker):
     assert g is not None
     assert isinstance(g, GraphExecutionState)
 
+
 def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
-    g = mock_invoker.create_execution_state(graph = simple_graph)
+    g = mock_invoker.create_execution_state(graph=simple_graph)
     mock_invoker.stop()
 
     assert g is not None
     assert isinstance(g, GraphExecutionState)
     assert g.graph == simple_graph
 
-@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
+
+# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_can_invoke(mock_invoker: Invoker, simple_graph):
-    g = mock_invoker.create_execution_state(graph = simple_graph)
+    g = mock_invoker.create_execution_state(graph=simple_graph)
     invocation_id = mock_invoker.invoke(g)
     assert invocation_id is not None
 
@@ -47,32 +85,34 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
         g = mock_invoker.services.graph_execution_manager.get(g.id)
         return len(g.executed) > 0
 
-    wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1)
+    wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
     mock_invoker.stop()
 
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert len(g.executed) > 0
 
-@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
+
+# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
-    g = mock_invoker.create_execution_state(graph = simple_graph)
-    invocation_id = mock_invoker.invoke(g, invoke_all = True)
+    g = mock_invoker.create_execution_state(graph=simple_graph)
+    invocation_id = mock_invoker.invoke(g, invoke_all=True)
     assert invocation_id is not None
 
     def has_executed_all(g: GraphExecutionState):
         g = mock_invoker.services.graph_execution_manager.get(g.id)
         return g.is_complete()
 
-    wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
+    wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
     mock_invoker.stop()
 
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert g.is_complete()
 
-@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
+
+# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_handles_errors(mock_invoker: Invoker):
     g = mock_invoker.create_execution_state()
-    g.graph.add_node(ErrorInvocation(id = "1"))
+    g.graph.add_node(ErrorInvocation(id="1"))
 
     mock_invoker.invoke(g, invoke_all=True)
 
@@ -80,11 +120,11 @@ def test_handles_errors(mock_invoker: Invoker):
         g = mock_invoker.services.graph_execution_manager.get(g.id)
         return g.is_complete()
 
-    wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1)
+    wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
     mock_invoker.stop()
 
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert g.has_error()
     assert g.is_complete()
 
-    assert all((i in g.errors for i in g.source_prepared_mapping['1']))
+    assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py
index 82818414b2..df7378150d 100644
--- a/tests/nodes/test_node_graph.py
+++ b/tests/nodes/test_node_graph.py
@@ -1,6 +1,5 @@
-from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
+from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
 from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
-from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
 from invokeai.app.invocations.upscale import UpscaleInvocation
 from invokeai.app.invocations.image import *
 from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
@@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
 
 # Tests
 def test_connections_are_compatible():
-    from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     from_field = "image"
     to_node = UpscaleInvocation(id = "2")
     to_field = "image"
@@ -28,7 +27,7 @@ def test_connections_are_compatible():
     assert result == True
 
 def test_connections_are_incompatible():
-    from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     from_field = "image"
     to_node = UpscaleInvocation(id = "2")
     to_field = "strength"
@@ -38,7 +37,7 @@ def test_connections_are_incompatible():
     assert result == False
 
 def test_connections_incompatible_with_invalid_fields():
-    from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     from_field = "invalid_field"
     to_node = UpscaleInvocation(id = "2")
     to_field = "image"
@@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields():
 
 def test_graph_can_add_node():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
 
     assert n.id in g.nodes
 
 def test_graph_fails_to_add_node_with_duplicate_id():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
-    n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second")
+    n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second")
 
     with pytest.raises(NodeAlreadyInGraphError):
         g.add_node(n2)
 
 def test_graph_updates_node():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
-    n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
+    n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
     g.add_node(n2)
 
-    nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated")
+    nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated")
 
     g.update_node("1", nu)
 
@@ -85,7 +84,7 @@ def test_graph_updates_node():
 
 def test_graph_fails_to_update_node_if_type_changes():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n2)
@@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes():
 
 def test_graph_allows_non_conflicting_id_change():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n2)
     e1 = create_edge(n.id,"image",n2.id,"image")
     g.add_edge(e1)
     
-    nu = TextToImageInvocation(id = "3", prompt = "Banana sushi")
+    nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi")
     g.update_node("1", nu)
 
     with pytest.raises(NodeNotFoundError):
@@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change():
 
 def test_graph_fails_to_update_node_id_if_conflict():
     g = Graph()
-    n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.add_node(n)
-    n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
+    n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
     g.add_node(n2)
     
-    nu = TextToImageInvocation(id = "2", prompt = "Banana sushi")
+    nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
     with pytest.raises(NodeAlreadyInGraphError):
         g.update_node("1", nu)
 
 def test_graph_adds_edge():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle():
 
 def test_graph_fails_to_add_edge_with_long_cycle():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     n3 = UpscaleInvocation(id = "3")
     g.add_node(n1)
@@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
 
 def test_graph_fails_to_add_edge_with_missing_node_id():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
 
 def test_graph_fails_to_add_edge_when_destination_exists():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     n3 = UpscaleInvocation(id = "3")
     g.add_node(n1)
@@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
 
 def test_graph_fails_to_add_edge_with_mismatched_types():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types():
 
 def test_graph_connects_collector():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
-    n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
+    n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2")
     n3 = CollectInvocation(id = "3")
     n4 = ListPassThroughInvocation(id = "4")
     g.add_node(n1)
@@ -224,7 +223,7 @@ def test_graph_connects_collector():
 
 def test_graph_collector_invalid_with_varying_input_types():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
     n3 = CollectInvocation(id = "3")
     g.add_node(n1)
@@ -282,7 +281,7 @@ def test_graph_connects_iterator():
     g = Graph()
     n1 = ListPassThroughInvocation(id = "1")
     n2 = IterateInvocation(id = "2")
-    n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
+    n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
     g.add_node(n1)
     g.add_node(n2)
     g.add_node(n3)
@@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
     g = Graph()
     n1 = ListPassThroughInvocation(id = "1")
     n2 = IterateInvocation(id = "2")
-    n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
+    n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
     n4 = ListPassThroughInvocation(id = "4")
     g.add_node(n1)
     g.add_node(n2)
@@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
 
 def test_graph_iterator_invalid_if_input_not_list():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = IterateInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
 
 def test_graph_validates():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -355,7 +354,7 @@ def test_graph_validates():
 
 def test_graph_invalid_if_edges_reference_missing_nodes():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     g.nodes[n1.id] = n1
     e1 = create_edge("1","image","2","image")
     g.edges.append(e1)
@@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid():
     n1 = GraphInvocation(id = "1")
     n1.graph = Graph()
 
-    n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
+    n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
     n1.graph.nodes[n1_1.id] = n1_1
     e1 = create_edge("1","image","2","image")
     n1.graph.edges.append(e1)
@@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle():
 
 def test_graph_invalid_with_invalid_connection():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.nodes[n1.id] = n1
     g.nodes[n2.id] = n2
@@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node():
     n1.graph = Graph()
     n1.graph.add_node
 
-    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n1.graph.add_node(n1_1)
 
     g.add_node(n1)
@@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
     n1.graph = Graph()
     n1.graph.add_node
 
-    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n1.graph.add_node(n1_1)
 
     g.add_node(n1)
@@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
     n1.graph = Graph()
     n1.graph.add_node
 
-    n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n1.graph.add_node(n1_1)
 
     g.add_node(n1)
@@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
 
 def test_graph_gets_networkx_graph():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph():
 # TODO: Graph serializes and deserializes
 def test_graph_can_serialize():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
@@ -541,7 +540,7 @@ def test_graph_can_serialize():
 
 def test_graph_can_deserialize():
     g = Graph()
-    n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
+    n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
     n2 = UpscaleInvocation(id = "2")
     g.add_node(n1)
     g.add_node(n2)
diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py
index d16d67d815..af011954c5 100644
--- a/tests/nodes/test_nodes.py
+++ b/tests/nodes/test_nodes.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, Union
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from invokeai.app.invocations.image import ImageField
 from invokeai.app.services.invocation_services import InvocationServices
@@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
 
     image: ImageField = Field()
 
-class ImageTestInvocation(BaseInvocation):
-    type: Literal['test_image'] = 'test_image'
+class TextToImageTestInvocation(BaseInvocation):
+    type: Literal['test_text_to_image'] = 'test_text_to_image'
 
     prompt: str = Field(default = "")
 
     def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
         return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
 
+class ImageToImageTestInvocation(BaseInvocation):
+    type: Literal['test_image_to_image'] = 'test_image_to_image'
+
+    prompt: str = Field(default = "")
+    image: Union[ImageField, None] = Field(default=None)
+
+    def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
+        return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
+
 class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
     type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
     collection: list[str] = Field(default_factory=list)
@@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
         return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
 
-
 from invokeai.app.services.events import EventServiceBase
 from invokeai.app.services.graph import Edge, EdgeConnection
 
diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py
deleted file mode 100644
index 53d2f2bfe6..0000000000
--- a/tests/test_textual_inversion.py
+++ /dev/null
@@ -1,301 +0,0 @@
-
-import unittest
-from typing import Union
-
-import torch
-
-from invokeai.backend.stable_diffusion import TextualInversionManager
-
-
-KNOWN_WORDS = ['a', 'b', 'c']
-KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
-UNKNOWN_WORDS = ['d', 'e', 'f']
-
-class DummyEmbeddingsList(list):
-    def __getattr__(self, name):
-        if name == 'num_embeddings':
-            return len(self)
-        elif name == 'weight':
-            return self
-        elif name == 'data':
-            return self
-
-def make_dummy_embedding():
-    return torch.randn([768])
-
-class DummyTransformer:
-
-
-    def __init__(self):
-        self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))])
-
-    def resize_token_embeddings(self, new_size=None):
-        if new_size is None:
-            return self.embeddings
-        else:
-            while len(self.embeddings) > new_size:
-                self.embeddings.pop(-1)
-            while len(self.embeddings) < new_size:
-                self.embeddings.append(make_dummy_embedding())
-
-    def get_input_embeddings(self):
-        return self.embeddings
-
-class DummyTokenizer():
-    def __init__(self):
-        self.tokens = KNOWN_WORDS.copy()
-        self.bos_token_id = 49406 # these are what the real CLIPTokenizer has
-        self.eos_token_id = 49407
-        self.pad_token_id = 49407
-        self.unk_token_id = 49407
-
-    def convert_tokens_to_ids(self, token_str):
-        try:
-            return self.tokens.index(token_str)
-        except ValueError:
-            return self.unk_token_id
-
-    def add_tokens(self, token_str):
-        if token_str in self.tokens:
-            return 0
-        self.tokens.append(token_str)
-        return 1
-
-
-class DummyClipEmbedder:
-    def __init__(self):
-        self.max_length = 77
-        self.transformer = DummyTransformer()
-        self.tokenizer = DummyTokenizer()
-        self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
-
-    def position_embedding(self, indices: Union[list,torch.Tensor]):
-        if type(indices) is list:
-            indices = torch.tensor(indices, dtype=int)
-        return torch.index_select(self.position_embeddings_tensor, 0, indices)
-
-
-def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
-    return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
-
-
-def make_dummy_textual_inversion_manager():
-    return TextualInversionManager(
-        tokenizer=DummyTokenizer(),
-        text_encoder=DummyTransformer()
-    )
-
-class TextualInversionManagerTestCase(unittest.TestCase):
-
-
-    def test_construction(self):
-        tim = make_dummy_textual_inversion_manager()
-
-    def test_add_embedding_for_known_token(self):
-        tim = make_dummy_textual_inversion_manager()
-        test_embedding = torch.randn([1, 768])
-        test_embedding_name = KNOWN_WORDS[0]
-        self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
-
-        pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-
-        ti = tim._add_textual_inversion(test_embedding_name, test_embedding)
-        self.assertEqual(ti.trigger_token_id, 0)
-
-
-        # check adding 'test' did not create a new word
-        embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-        self.assertEqual(pre_embeddings_count, embeddings_count)
-
-        # check it was added
-        self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
-        textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
-        self.assertIsNotNone(textual_inversion)
-        self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
-        self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id)
-
-    def test_add_embedding_for_unknown_token(self):
-        tim = make_dummy_textual_inversion_manager()
-        test_embedding_1 = torch.randn([1, 768])
-        test_embedding_name_1 = UNKNOWN_WORDS[0]
-
-        pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-
-        added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id
-        # new token id should get added on the end
-        self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
-
-        # check adding did create a new word
-        embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-        self.assertEqual(pre_embeddings_count+1, embeddings_count)
-
-        # check it was added
-        self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
-        textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
-        self.assertIsNotNone(textual_inversion)
-        self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
-        self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
-
-        # add another one
-        test_embedding_2 = torch.randn([1, 768])
-        test_embedding_name_2 = UNKNOWN_WORDS[1]
-
-        pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-
-        added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id
-        self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
-
-        # check adding did create a new word
-        embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-        self.assertEqual(pre_embeddings_count+1, embeddings_count)
-
-        # check it was added
-        self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
-        textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2)
-        self.assertIsNotNone(textual_inversion)
-        self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
-        self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2)
-
-        # check the old one is still there
-        self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
-        textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
-        self.assertIsNotNone(textual_inversion)
-        self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
-        self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
-
-
-    def test_pad_raises_on_eos_bos(self):
-        tim = make_dummy_textual_inversion_manager()
-        prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \
-                                         [KNOWN_WORDS_TOKEN_IDS] + \
-                                         [tim.tokenizer.eos_token_id]
-        with self.assertRaises(ValueError):
-            tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos)
-
-    def test_pad_tokens_list_vector_length_1(self):
-        tim = make_dummy_textual_inversion_manager()
-        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
-
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
-        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
-
-        test_embedding_1v = torch.randn([1, 768])
-        test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
-        test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id
-        self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
-
-        # at the end
-        prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append)
-        self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
-
-        # at the start
-        prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend)
-        self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
-
-        # in the middle
-        prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert)
-        self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
-
-    def test_pad_tokens_list_vector_length_2(self):
-        tim = make_dummy_textual_inversion_manager()
-        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
-
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
-        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
-
-        test_embedding_2v = torch.randn([2, 768])
-        test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
-        test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id
-        test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids
-        self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
-
-        # at the end
-        prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append)
-        self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
-        self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids)
-
-        # at the start
-        prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend)
-        self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
-        self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
-
-        # in the middle
-        prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert)
-        self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
-        self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
-
-    def test_pad_tokens_list_vector_length_8(self):
-        tim = make_dummy_textual_inversion_manager()
-        prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
-
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
-        self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
-
-        test_embedding_8v = torch.randn([8, 768])
-        test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
-        test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id
-        test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids
-        self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
-
-        # at the end
-        prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append)
-        self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
-        self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids)
-
-        # at the start
-        prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend)
-        self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
-        self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
-
-        # in the middle
-        prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
-        expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert)
-        self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
-        self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
-
-
-    def test_deferred_loading(self):
-        tim = make_dummy_textual_inversion_manager()
-        test_embedding = torch.randn([1, 768])
-        test_embedding_name = UNKNOWN_WORDS[0]
-        self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
-
-        pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-
-        ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True)
-        self.assertIsNone(ti.trigger_token_id)
-
-        # check that a new word is not yet created
-        embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-        self.assertEqual(pre_embeddings_count, embeddings_count)
-
-        # check it was added
-        self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
-        textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
-        self.assertIsNotNone(textual_inversion)
-        self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
-        self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id)
-
-        # check it lazy-loads
-        prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]])
-        tim.create_deferred_token_ids_for_any_trigger_terms(prompt)
-
-        embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
-        self.assertEqual(pre_embeddings_count+1, embeddings_count)
-
-        textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
-        self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
-        self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))