diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 10d1d91920..4e1da3b040 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.images import ImageServiceABC - from invokeai.backend import ModelManager + from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.restoration_services import RestorationServices @@ -22,46 +22,47 @@ class InvocationServices: """Services that can be used by invocations""" # TODO: Just forward-declared everything due to circular dependencies. Fix structure. - events: "EventServiceBase" - latents: "LatentsStorageBase" - queue: "InvocationQueueABC" - model_manager: "ModelManager" - restoration: "RestorationServices" - configuration: "InvokeAISettings" - images: "ImageServiceABC" - boards: "BoardServiceABC" board_images: "BoardImagesServiceABC" - graph_library: "ItemStorageABC"["LibraryGraph"] + boards: "BoardServiceABC" + configuration: "InvokeAISettings" + events: "EventServiceBase" graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] + graph_library: "ItemStorageABC"["LibraryGraph"] + images: "ImageServiceABC" + latents: "LatentsStorageBase" + logger: "Logger" + model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + queue: "InvocationQueueABC" + restoration: "RestorationServices" def __init__( self, - model_manager: "ModelManager", - events: "EventServiceBase", - logger: "Logger", - latents: "LatentsStorageBase", - images: "ImageServiceABC", - boards: "BoardServiceABC", board_images: "BoardImagesServiceABC", - queue: "InvocationQueueABC", - graph_library: "ItemStorageABC"["LibraryGraph"], - graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], - processor: "InvocationProcessorABC", - restoration: "RestorationServices", + boards: "BoardServiceABC", configuration: "InvokeAISettings", + events: "EventServiceBase", + graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], + graph_library: "ItemStorageABC"["LibraryGraph"], + images: "ImageServiceABC", + latents: "LatentsStorageBase", + logger: "Logger", + model_manager: "ModelManagerServiceBase", + processor: "InvocationProcessorABC", + queue: "InvocationQueueABC", + restoration: "RestorationServices", ): - self.model_manager = model_manager - self.events = events - self.logger = logger - self.latents = latents - self.images = images - self.boards = boards self.board_images = board_images - self.queue = queue - self.graph_library = graph_library - self.graph_execution_manager = graph_execution_manager - self.processor = processor - self.restoration = restoration - self.configuration = configuration self.boards = boards + self.boards = boards + self.configuration = configuration + self.events = events + self.graph_execution_manager = graph_execution_manager + self.graph_library = graph_library + self.images = images + self.latents = latents + self.logger = logger + self.model_manager = model_manager + self.processor = processor + self.queue = queue + self.restoration = restoration diff --git a/invokeai/backend/install/legacy_arg_parsing.py b/invokeai/backend/install/legacy_arg_parsing.py index 4a58ff8336..684c50c77d 100644 --- a/invokeai/backend/install/legacy_arg_parsing.py +++ b/invokeai/backend/install/legacy_arg_parsing.py @@ -4,6 +4,8 @@ import argparse import shlex from argparse import ArgumentParser +# note that this includes both old sampler names and new scheduler names +# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init SAMPLER_CHOICES = [ "ddim", "ddpm", @@ -27,6 +29,15 @@ SAMPLER_CHOICES = [ "dpmpp_sde", "dpmpp_sde_k", "unipc", + "k_dpm_2_a", + "k_dpm_2", + "k_dpmpp_2_a", + "k_dpmpp_2", + "k_euler_a", + "k_euler", + "k_heun", + "k_lms", + "plms", ] PRECISION_CHOICES = [ diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 06502b6c41..0000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_queue import MemoryInvocationQueue -from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory -from invokeai.app.services.graph import LibraryGraph, GraphExecutionState -from invokeai.app.services.processor import DefaultInvocationProcessor - -# Ignore these files as they need to be rewritten following the model manager refactor -collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"] - -@pytest.fixture(scope="session", autouse=True) -def mock_services(): - # NOTE: none of these are actually called by the test invocations - return InvocationServices( - model_manager = None, # type: ignore - events = None, # type: ignore - logger = None, # type: ignore - images = None, # type: ignore - latents = None, # type: ignore - board_images=None, # type: ignore - boards=None, # type: ignore - queue = MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=sqlite_memory, table_name="graphs" - ), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor(), - restoration = None, # type: ignore - configuration = None, # type: ignore - ) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index df8964da18..f34b18310b 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,41 +1,81 @@ -import pytest - -from invokeai.app.invocations.baseinvocation import (BaseInvocation, - BaseInvocationOutput, - InvocationContext) +from .test_invoker import create_edge +from .test_nodes import ( + TestEventService, + TextToImageTestInvocation, + PromptTestInvocation, + PromptCollectionTestInvocation, +) +from invokeai.app.services.invocation_queue import MemoryInvocationQueue +from invokeai.app.services.processor import DefaultInvocationProcessor +from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, +) from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.services.graph import (CollectInvocation, Graph, - GraphExecutionState, - IterateInvocation) from invokeai.app.services.invocation_services import InvocationServices - -from .test_invoker import create_edge -from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation, - PromptTestInvocation) +from invokeai.app.services.graph import ( + Graph, + CollectInvocation, + IterateInvocation, + GraphExecutionState, + LibraryGraph, +) +import pytest @pytest.fixture def simple_graph(): g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) + g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) + g.add_node(TextToImageTestInvocation(id="2")) g.add_edge(create_edge("1", "prompt", "2", "prompt")) return g -def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: + +# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types +# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate +# the test invocations. +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + model_manager = None, # type: ignore + events = TestEventService(), + logger = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore + boards = None, # type: ignore + board_images = None, # type: ignore + queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor(), + restoration = None, # type: ignore + configuration = None, # type: ignore + ) + + +def invoke_next( + g: GraphExecutionState, services: InvocationServices +) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() if n is None: return (None, None) - print(f'invoking {n.id}: {type(n)}') + print(f"invoking {n.id}: {type(n)}") o = n.invoke(InvocationContext(services, "1")) g.complete(n.id, o) return (n, o) + def test_graph_state_executes_in_order(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) @@ -47,38 +87,42 @@ def test_graph_state_executes_in_order(simple_graph, mock_services): assert g.results[n1[0].id].prompt == n1[0].prompt assert n2[0].prompt == n1[0].prompt + def test_graph_is_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) n3 = g.next() assert g.is_complete() + def test_graph_is_not_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = g.next() assert not g.is_complete() + # TODO: test completion with iterators/subgraphs + def test_graph_state_expands_iterator(mock_services): graph = Graph() - graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1)) - graph.add_node(IterateInvocation(id = "1")) - graph.add_node(MultiplyInvocation(id = "2", b = 10)) - graph.add_node(AddInvocation(id = "3", b = 1)) + graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1)) + graph.add_node(IterateInvocation(id="1")) + graph.add_node(MultiplyInvocation(id="2", b=10)) + graph.add_node(AddInvocation(id="3", b=1)) graph.add_edge(create_edge("0", "collection", "1", "collection")) graph.add_edge(create_edge("1", "item", "2", "a")) graph.add_edge(create_edge("2", "a", "3", "a")) - g = GraphExecutionState(graph = graph) + g = GraphExecutionState(graph=graph) while not g.is_complete(): invoke_next(g, mock_services) - prepared_add_nodes = g.source_prepared_mapping['3'] + prepared_add_nodes = g.source_prepared_mapping["3"] results = set([g.results[n].a for n in prepared_add_nodes]) expected = set([1, 11, 21]) assert results == expected @@ -87,15 +131,17 @@ def test_graph_state_expands_iterator(mock_services): def test_graph_state_collects(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(PromptTestInvocation(id = "3")) - graph.add_node(CollectInvocation(id = "4")) + graph.add_node( + PromptCollectionTestInvocation(id="1", collection=list(test_prompts)) + ) + graph.add_node(IterateInvocation(id="2")) + graph.add_node(PromptTestInvocation(id="3")) + graph.add_node(CollectInvocation(id="4")) graph.add_edge(create_edge("1", "collection", "2", "collection")) graph.add_edge(create_edge("2", "item", "3", "prompt")) graph.add_edge(create_edge("3", "prompt", "4", "item")) - g = GraphExecutionState(graph = graph) + g = GraphExecutionState(graph=graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) n3 = invoke_next(g, mock_services) @@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node( + PromptCollectionTestInvocation( + id="prompt_collection", collection=list(test_prompts) + ) + ) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) - graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge( + create_edge("prompt_collection", "collection", "iterate", "collection") + ) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) # separated, fully-preparable chain of nodes @@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node( + PromptCollectionTestInvocation( + id="prompt_collection", collection=list(test_prompts) + ) + ) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) graph.add_node(PromptTestInvocation(id="prompt_successor")) - graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge( + create_edge("prompt_collection", "collection", "iterate", "collection") + ) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) - graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) + graph.add_edge( + create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt") + ) g = GraphExecutionState(graph=graph) n1 = invoke_next(g, mock_services) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 4331e62d21..19d7dd20b3 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -1,26 +1,62 @@ -import pytest - -from invokeai.app.services.graph import Graph, GraphExecutionState -from invokeai.app.services.invocation_services import InvocationServices +from .test_nodes import ( + TestEventService, + ErrorInvocation, + TextToImageTestInvocation, + PromptTestInvocation, + create_edge, + wait_until, +) +from invokeai.app.services.invocation_queue import MemoryInvocationQueue +from invokeai.app.services.processor import DefaultInvocationProcessor +from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker - -from .test_nodes import (ErrorInvocation, ImageTestInvocation, - PromptTestInvocation, create_edge, wait_until) +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.graph import ( + Graph, + GraphExecutionState, + LibraryGraph, +) +import pytest @pytest.fixture def simple_graph(): g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) + g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) + g.add_node(TextToImageTestInvocation(id="2")) g.add_edge(create_edge("1", "prompt", "2", "prompt")) return g + +# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types +# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate +# the test invocations. +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + model_manager = None, # type: ignore + events = TestEventService(), + logger = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore + boards = None, # type: ignore + board_images = None, # type: ignore + queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor(), + restoration = None, # type: ignore + configuration = None, # type: ignore + ) + + @pytest.fixture() def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker( - services = mock_services - ) + return Invoker(services=mock_services) + def test_can_create_graph_state(mock_invoker: Invoker): g = mock_invoker.create_execution_state() @@ -29,17 +65,19 @@ def test_can_create_graph_state(mock_invoker: Invoker): assert g is not None assert isinstance(g, GraphExecutionState) + def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) + g = mock_invoker.create_execution_state(graph=simple_graph) mock_invoker.stop() assert g is not None assert isinstance(g, GraphExecutionState) assert g.graph == simple_graph -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) + g = mock_invoker.create_execution_state(graph=simple_graph) invocation_id = mock_invoker.invoke(g) assert invocation_id is not None @@ -47,32 +85,34 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): g = mock_invoker.services.graph_execution_manager.get(g.id) return len(g.executed) > 0 - wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_any(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke_all(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) - invocation_id = mock_invoker.invoke(g, invoke_all = True) + g = mock_invoker.create_execution_state(graph=simple_graph) + invocation_id = mock_invoker.invoke(g, invoke_all=True) assert invocation_id is not None def has_executed_all(g: GraphExecutionState): g = mock_invoker.services.graph_execution_manager.get(g.id) return g.is_complete() - wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_all(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.is_complete() -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.create_execution_state() - g.graph.add_node(ErrorInvocation(id = "1")) + g.graph.add_node(ErrorInvocation(id="1")) mock_invoker.invoke(g, invoke_all=True) @@ -80,11 +120,11 @@ def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.services.graph_execution_manager.get(g.id) return g.is_complete() - wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_all(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.has_error() assert g.is_complete() - assert all((i in g.errors for i in g.source_prepared_mapping['1'])) + assert all((i in g.errors for i in g.source_prepared_mapping["1"])) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index 82818414b2..df7378150d 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,6 +1,5 @@ -from .test_nodes import ListPassThroughInvocation, PromptTestInvocation +from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation -from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.image import * from invokeai.app.invocations.math import AddInvocation, SubtractInvocation @@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg # Tests def test_connections_are_compatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" to_node = UpscaleInvocation(id = "2") to_field = "image" @@ -28,7 +27,7 @@ def test_connections_are_compatible(): assert result == True def test_connections_are_incompatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" to_node = UpscaleInvocation(id = "2") to_field = "strength" @@ -38,7 +37,7 @@ def test_connections_are_incompatible(): assert result == False def test_connections_incompatible_with_invalid_fields(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "invalid_field" to_node = UpscaleInvocation(id = "2") to_field = "image" @@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields(): def test_graph_can_add_node(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) assert n.id in g.nodes def test_graph_fails_to_add_node_with_duplicate_id(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second") with pytest.raises(NodeAlreadyInGraphError): g.add_node(n2) def test_graph_updates_node(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") g.add_node(n2) - nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") + nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated") g.update_node("1", nu) @@ -85,7 +84,7 @@ def test_graph_updates_node(): def test_graph_fails_to_update_node_if_type_changes(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) n2 = UpscaleInvocation(id = "2") g.add_node(n2) @@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes(): def test_graph_allows_non_conflicting_id_change(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) n2 = UpscaleInvocation(id = "2") g.add_node(n2) e1 = create_edge(n.id,"image",n2.id,"image") g.add_edge(e1) - nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") + nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi") g.update_node("1", nu) with pytest.raises(NodeNotFoundError): @@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change(): def test_graph_fails_to_update_node_id_if_conflict(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") g.add_node(n2) - nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") + nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") with pytest.raises(NodeAlreadyInGraphError): g.update_node("1", nu) def test_graph_adds_edge(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle(): def test_graph_fails_to_add_edge_with_long_cycle(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") n3 = UpscaleInvocation(id = "3") g.add_node(n1) @@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_missing_node_id(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_when_destination_exists(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") n3 = UpscaleInvocation(id = "3") g.add_node(n1) @@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_with_mismatched_types(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types(): def test_graph_connects_collector(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2") n3 = CollectInvocation(id = "3") n4 = ListPassThroughInvocation(id = "4") g.add_node(n1) @@ -224,7 +223,7 @@ def test_graph_connects_collector(): def test_graph_collector_invalid_with_varying_input_types(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") n3 = CollectInvocation(id = "3") g.add_node(n1) @@ -282,7 +281,7 @@ def test_graph_connects_iterator(): g = Graph() n1 = ListPassThroughInvocation(id = "1") n2 = IterateInvocation(id = "2") - n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") g.add_node(n1) g.add_node(n2) g.add_node(n3) @@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs(): g = Graph() n1 = ListPassThroughInvocation(id = "1") n2 = IterateInvocation(id = "2") - n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") n4 = ListPassThroughInvocation(id = "4") g.add_node(n1) g.add_node(n2) @@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs(): def test_graph_iterator_invalid_if_input_not_list(): g = Graph() - n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = IterateInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different(): def test_graph_validates(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -355,7 +354,7 @@ def test_graph_validates(): def test_graph_invalid_if_edges_reference_missing_nodes(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.nodes[n1.id] = n1 e1 = create_edge("1","image","2","image") g.edges.append(e1) @@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid(): n1 = GraphInvocation(id = "1") n1.graph = Graph() - n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") n1.graph.nodes[n1_1.id] = n1_1 e1 = create_edge("1","image","2","image") n1.graph.edges.append(e1) @@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle(): def test_graph_invalid_with_invalid_connection(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 @@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): def test_graph_gets_networkx_graph(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph(): # TODO: Graph serializes and deserializes def test_graph_can_serialize(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -541,7 +540,7 @@ def test_graph_can_serialize(): def test_graph_can_deserialize(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index d16d67d815..af011954c5 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.image import ImageField from invokeai.app.services.invocation_services import InvocationServices @@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput): image: ImageField = Field() -class ImageTestInvocation(BaseInvocation): - type: Literal['test_image'] = 'test_image' +class TextToImageTestInvocation(BaseInvocation): + type: Literal['test_text_to_image'] = 'test_text_to_image' prompt: str = Field(default = "") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) +class ImageToImageTestInvocation(BaseInvocation): + type: Literal['test_image_to_image'] = 'test_image_to_image' + + prompt: str = Field(default = "") + image: Union[ImageField, None] = Field(default=None) + + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) + class PromptCollectionTestInvocationOutput(BaseInvocationOutput): type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' collection: list[str] = Field(default_factory=list) @@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) - from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Edge, EdgeConnection diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py deleted file mode 100644 index 53d2f2bfe6..0000000000 --- a/tests/test_textual_inversion.py +++ /dev/null @@ -1,301 +0,0 @@ - -import unittest -from typing import Union - -import torch - -from invokeai.backend.stable_diffusion import TextualInversionManager - - -KNOWN_WORDS = ['a', 'b', 'c'] -KNOWN_WORDS_TOKEN_IDS = [0, 1, 2] -UNKNOWN_WORDS = ['d', 'e', 'f'] - -class DummyEmbeddingsList(list): - def __getattr__(self, name): - if name == 'num_embeddings': - return len(self) - elif name == 'weight': - return self - elif name == 'data': - return self - -def make_dummy_embedding(): - return torch.randn([768]) - -class DummyTransformer: - - - def __init__(self): - self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))]) - - def resize_token_embeddings(self, new_size=None): - if new_size is None: - return self.embeddings - else: - while len(self.embeddings) > new_size: - self.embeddings.pop(-1) - while len(self.embeddings) < new_size: - self.embeddings.append(make_dummy_embedding()) - - def get_input_embeddings(self): - return self.embeddings - -class DummyTokenizer(): - def __init__(self): - self.tokens = KNOWN_WORDS.copy() - self.bos_token_id = 49406 # these are what the real CLIPTokenizer has - self.eos_token_id = 49407 - self.pad_token_id = 49407 - self.unk_token_id = 49407 - - def convert_tokens_to_ids(self, token_str): - try: - return self.tokens.index(token_str) - except ValueError: - return self.unk_token_id - - def add_tokens(self, token_str): - if token_str in self.tokens: - return 0 - self.tokens.append(token_str) - return 1 - - -class DummyClipEmbedder: - def __init__(self): - self.max_length = 77 - self.transformer = DummyTransformer() - self.tokenizer = DummyTokenizer() - self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32) - - def position_embedding(self, indices: Union[list,torch.Tensor]): - if type(indices) is list: - indices = torch.tensor(indices, dtype=int) - return torch.index_select(self.position_embeddings_tensor, 0, indices) - - -def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool: - return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices)) - - -def make_dummy_textual_inversion_manager(): - return TextualInversionManager( - tokenizer=DummyTokenizer(), - text_encoder=DummyTransformer() - ) - -class TextualInversionManagerTestCase(unittest.TestCase): - - - def test_construction(self): - tim = make_dummy_textual_inversion_manager() - - def test_add_embedding_for_known_token(self): - tim = make_dummy_textual_inversion_manager() - test_embedding = torch.randn([1, 768]) - test_embedding_name = KNOWN_WORDS[0] - self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - ti = tim._add_textual_inversion(test_embedding_name, test_embedding) - self.assertEqual(ti.trigger_token_id, 0) - - - # check adding 'test' did not create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id) - - def test_add_embedding_for_unknown_token(self): - tim = make_dummy_textual_inversion_manager() - test_embedding_1 = torch.randn([1, 768]) - test_embedding_name_1 = UNKNOWN_WORDS[0] - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id - # new token id should get added on the end - self.assertEqual(added_token_id_1, len(KNOWN_WORDS)) - - # check adding did create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1) - - # add another one - test_embedding_2 = torch.randn([1, 768]) - test_embedding_name_2 = UNKNOWN_WORDS[1] - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id - self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1) - - # check adding did create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2) - - # check the old one is still there - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1) - - - def test_pad_raises_on_eos_bos(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \ - [KNOWN_WORDS_TOKEN_IDS] + \ - [tim.tokenizer.eos_token_id] - with self.assertRaises(ValueError): - tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos) - - def test_pad_tokens_list_vector_length_1(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_1v = torch.randn([1, 768]) - test_embedding_1v_token = "" - test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id - self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append) - self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend) - self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert) - self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids) - - def test_pad_tokens_list_vector_length_2(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_2v = torch.randn([2, 768]) - test_embedding_2v_token = "" - test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id - test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids - self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append) - self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend) - self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids) - self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert) - self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids) - - def test_pad_tokens_list_vector_length_8(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_8v = torch.randn([8, 768]) - test_embedding_8v_token = "" - test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id - test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids - self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append) - self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend) - self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids) - self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert) - self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids) - - - def test_deferred_loading(self): - tim = make_dummy_textual_inversion_manager() - test_embedding = torch.randn([1, 768]) - test_embedding_name = UNKNOWN_WORDS[0] - self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True) - self.assertIsNone(ti.trigger_token_id) - - # check that a new word is not yet created - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id) - - # check it lazy-loads - prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]]) - tim.create_deferred_token_ids_for_any_trigger_terms(prompt) - - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))