Merge branch 'main' into lstein/fix-vae-conversion-crash

This commit is contained in:
Lincoln Stein 2023-07-03 14:03:13 -04:00 committed by GitHub
commit cfd09214d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 465 deletions

View File

@ -7,7 +7,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices from invokeai.app.services.restoration_services import RestorationServices
@ -22,46 +22,47 @@ class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure. # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC" board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"] boards: "BoardServiceABC"
configuration: "InvokeAISettings"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
images: "ImageServiceABC"
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__( def __init__(
self, self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC", board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC", boards: "BoardServiceABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings", configuration: "InvokeAISettings",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
): ):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
self.boards = boards self.boards = boards
self.boards = boards
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration

View File

@ -4,6 +4,8 @@ import argparse
import shlex import shlex
from argparse import ArgumentParser from argparse import ArgumentParser
# note that this includes both old sampler names and new scheduler names
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"ddpm", "ddpm",
@ -27,6 +29,15 @@ SAMPLER_CHOICES = [
"dpmpp_sde", "dpmpp_sde",
"dpmpp_sde_k", "dpmpp_sde_k",
"unipc", "unipc",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
] ]
PRECISION_CHOICES = [ PRECISION_CHOICES = [

View File

@ -1,30 +0,0 @@
import pytest
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
from invokeai.app.services.processor import DefaultInvocationProcessor
# Ignore these files as they need to be rewritten following the model manager refactor
collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
@pytest.fixture(scope="session", autouse=True)
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
board_images=None, # type: ignore
boards=None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -1,39 +1,79 @@
import pytest from .test_invoker import create_edge
from .test_nodes import (
from invokeai.app.invocations.baseinvocation import (BaseInvocation, TestEventService,
TextToImageTestInvocation,
PromptTestInvocation,
PromptCollectionTestInvocation,
)
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
InvocationContext) InvocationContext,
)
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.graph import (CollectInvocation, Graph,
GraphExecutionState,
IterateInvocation)
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import (
from .test_invoker import create_edge Graph,
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation, CollectInvocation,
PromptTestInvocation) IterateInvocation,
GraphExecutionState,
LibraryGraph,
)
import pytest
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
g = Graph() g = Graph()
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(ImageTestInvocation(id = "2")) g.add_node(TextToImageTestInvocation(id="2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
boards = None, # type: ignore
board_images = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
def invoke_next(
g: GraphExecutionState, services: InvocationServices
) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next() n = g.next()
if n is None: if n is None:
return (None, None) return (None, None)
print(f'invoking {n.id}: {type(n)}') print(f"invoking {n.id}: {type(n)}")
o = n.invoke(InvocationContext(services, "1")) o = n.invoke(InvocationContext(services, "1"))
g.complete(n.id, o) g.complete(n.id, o)
return (n, o) return (n, o)
def test_graph_state_executes_in_order(simple_graph, mock_services): def test_graph_state_executes_in_order(simple_graph, mock_services):
g = GraphExecutionState(graph=simple_graph) g = GraphExecutionState(graph=simple_graph)
@ -47,6 +87,7 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
assert g.results[n1[0].id].prompt == n1[0].prompt assert g.results[n1[0].id].prompt == n1[0].prompt
assert n2[0].prompt == n1[0].prompt assert n2[0].prompt == n1[0].prompt
def test_graph_is_complete(simple_graph, mock_services): def test_graph_is_complete(simple_graph, mock_services):
g = GraphExecutionState(graph=simple_graph) g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services) n1 = invoke_next(g, mock_services)
@ -55,6 +96,7 @@ def test_graph_is_complete(simple_graph, mock_services):
assert g.is_complete() assert g.is_complete()
def test_graph_is_not_complete(simple_graph, mock_services): def test_graph_is_not_complete(simple_graph, mock_services):
g = GraphExecutionState(graph=simple_graph) g = GraphExecutionState(graph=simple_graph)
n1 = invoke_next(g, mock_services) n1 = invoke_next(g, mock_services)
@ -62,8 +104,10 @@ def test_graph_is_not_complete(simple_graph, mock_services):
assert not g.is_complete() assert not g.is_complete()
# TODO: test completion with iterators/subgraphs # TODO: test completion with iterators/subgraphs
def test_graph_state_expands_iterator(mock_services): def test_graph_state_expands_iterator(mock_services):
graph = Graph() graph = Graph()
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1)) graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
@ -78,7 +122,7 @@ def test_graph_state_expands_iterator(mock_services):
while not g.is_complete(): while not g.is_complete():
invoke_next(g, mock_services) invoke_next(g, mock_services)
prepared_add_nodes = g.source_prepared_mapping['3'] prepared_add_nodes = g.source_prepared_mapping["3"]
results = set([g.results[n].a for n in prepared_add_nodes]) results = set([g.results[n].a for n in prepared_add_nodes])
expected = set([1, 11, 21]) expected = set([1, 11, 21])
assert results == expected assert results == expected
@ -87,7 +131,9 @@ def test_graph_state_expands_iterator(mock_services):
def test_graph_state_collects(mock_services): def test_graph_state_collects(mock_services):
graph = Graph() graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"] test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) graph.add_node(
PromptCollectionTestInvocation(id="1", collection=list(test_prompts))
)
graph.add_node(IterateInvocation(id="2")) graph.add_node(IterateInvocation(id="2"))
graph.add_node(PromptTestInvocation(id="3")) graph.add_node(PromptTestInvocation(id="3"))
graph.add_node(CollectInvocation(id="4")) graph.add_node(CollectInvocation(id="4"))
@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services):
graph = Graph() graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"] test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(IterateInvocation(id="iterate")) graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated")) graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
# separated, fully-preparable chain of nodes # separated, fully-preparable chain of nodes
@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services):
graph = Graph() graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"] test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) graph.add_node(
PromptCollectionTestInvocation(
id="prompt_collection", collection=list(test_prompts)
)
)
graph.add_node(IterateInvocation(id="iterate")) graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated")) graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor")) graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) graph.add_edge(
create_edge("prompt_collection", "collection", "iterate", "collection")
)
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) graph.add_edge(
create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")
)
g = GraphExecutionState(graph=graph) g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services) n1 = invoke_next(g, mock_services)

View File

@ -1,26 +1,62 @@
import pytest from .test_nodes import (
TestEventService,
from invokeai.app.services.graph import Graph, GraphExecutionState ErrorInvocation,
from invokeai.app.services.invocation_services import InvocationServices TextToImageTestInvocation,
PromptTestInvocation,
create_edge,
wait_until,
)
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.invocation_services import InvocationServices
from .test_nodes import (ErrorInvocation, ImageTestInvocation, from invokeai.app.services.graph import (
PromptTestInvocation, create_edge, wait_until) Graph,
GraphExecutionState,
LibraryGraph,
)
import pytest
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
g = Graph() g = Graph()
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
g.add_node(ImageTestInvocation(id = "2")) g.add_node(TextToImageTestInvocation(id="2"))
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
boards = None, # type: ignore
board_images = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
@pytest.fixture() @pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker: def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker( return Invoker(services=mock_services)
services = mock_services
)
def test_can_create_graph_state(mock_invoker: Invoker): def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
@ -29,6 +65,7 @@ def test_can_create_graph_state(mock_invoker: Invoker):
assert g is not None assert g is not None
assert isinstance(g, GraphExecutionState) assert isinstance(g, GraphExecutionState)
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
mock_invoker.stop() mock_invoker.stop()
@ -37,7 +74,8 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
assert isinstance(g, GraphExecutionState) assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph assert g.graph == simple_graph
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(g) invocation_id = mock_invoker.invoke(g)
@ -53,7 +91,8 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0 assert len(g.executed) > 0
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph) g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all=True) invocation_id = mock_invoker.invoke(g, invoke_all=True)
@ -69,7 +108,8 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete() assert g.is_complete()
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker): def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1")) g.graph.add_node(ErrorInvocation(id="1"))
@ -87,4 +127,4 @@ def test_handles_errors(mock_invoker: Invoker):
assert g.has_error() assert g.has_error()
assert g.is_complete() assert g.is_complete()
assert all((i in g.errors for i in g.source_prepared_mapping['1'])) assert all((i in g.errors for i in g.source_prepared_mapping["1"]))

View File

@ -1,6 +1,5 @@
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.image import * from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
# Tests # Tests
def test_connections_are_compatible(): def test_connections_are_compatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "image" to_field = "image"
@ -28,7 +27,7 @@ def test_connections_are_compatible():
assert result == True assert result == True
def test_connections_are_incompatible(): def test_connections_are_incompatible():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image" from_field = "image"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "strength" to_field = "strength"
@ -38,7 +37,7 @@ def test_connections_are_incompatible():
assert result == False assert result == False
def test_connections_incompatible_with_invalid_fields(): def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field" from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2") to_node = UpscaleInvocation(id = "2")
to_field = "image" to_field = "image"
@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields():
def test_graph_can_add_node(): def test_graph_can_add_node():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
assert n.id in g.nodes assert n.id in g.nodes
def test_graph_fails_to_add_node_with_duplicate_id(): def test_graph_fails_to_add_node_with_duplicate_id():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second")
with pytest.raises(NodeAlreadyInGraphError): with pytest.raises(NodeAlreadyInGraphError):
g.add_node(n2) g.add_node(n2)
def test_graph_updates_node(): def test_graph_updates_node():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2) g.add_node(n2)
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated")
g.update_node("1", nu) g.update_node("1", nu)
@ -85,7 +84,7 @@ def test_graph_updates_node():
def test_graph_fails_to_update_node_if_type_changes(): def test_graph_fails_to_update_node_if_type_changes():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes():
def test_graph_allows_non_conflicting_id_change(): def test_graph_allows_non_conflicting_id_change():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n2) g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image") e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1) g.add_edge(e1)
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi")
g.update_node("1", nu) g.update_node("1", nu)
with pytest.raises(NodeNotFoundError): with pytest.raises(NodeNotFoundError):
@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change():
def test_graph_fails_to_update_node_id_if_conflict(): def test_graph_fails_to_update_node_id_if_conflict():
g = Graph() g = Graph()
n = TextToImageInvocation(id = "1", prompt = "Banana sushi") n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n) g.add_node(n)
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
g.add_node(n2) g.add_node(n2)
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
with pytest.raises(NodeAlreadyInGraphError): with pytest.raises(NodeAlreadyInGraphError):
g.update_node("1", nu) g.update_node("1", nu)
def test_graph_adds_edge(): def test_graph_adds_edge():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle():
def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = UpscaleInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3") n3 = UpscaleInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
def test_graph_fails_to_add_edge_with_mismatched_types(): def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types():
def test_graph_connects_collector(): def test_graph_connects_collector():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2")
n3 = CollectInvocation(id = "3") n3 = CollectInvocation(id = "3")
n4 = ListPassThroughInvocation(id = "4") n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1) g.add_node(n1)
@ -224,7 +223,7 @@ def test_graph_connects_collector():
def test_graph_collector_invalid_with_varying_input_types(): def test_graph_collector_invalid_with_varying_input_types():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
n3 = CollectInvocation(id = "3") n3 = CollectInvocation(id = "3")
g.add_node(n1) g.add_node(n1)
@ -282,7 +281,7 @@ def test_graph_connects_iterator():
g = Graph() g = Graph()
n1 = ListPassThroughInvocation(id = "1") n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
g.add_node(n3) g.add_node(n3)
@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
g = Graph() g = Graph()
n1 = ListPassThroughInvocation(id = "1") n1 = ListPassThroughInvocation(id = "1")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
n4 = ListPassThroughInvocation(id = "4") n4 = ListPassThroughInvocation(id = "4")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
def test_graph_iterator_invalid_if_input_not_list(): def test_graph_iterator_invalid_if_input_not_list():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = IterateInvocation(id = "2") n2 = IterateInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
def test_graph_validates(): def test_graph_validates():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -355,7 +354,7 @@ def test_graph_validates():
def test_graph_invalid_if_edges_reference_missing_nodes(): def test_graph_invalid_if_edges_reference_missing_nodes():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
g.edges.append(e1) g.edges.append(e1)
@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid():
n1 = GraphInvocation(id = "1") n1 = GraphInvocation(id = "1")
n1.graph = Graph() n1.graph = Graph()
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
n1.graph.nodes[n1_1.id] = n1_1 n1.graph.nodes[n1_1.id] = n1_1
e1 = create_edge("1","image","2","image") e1 = create_edge("1","image","2","image")
n1.graph.edges.append(e1) n1.graph.edges.append(e1)
@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle():
def test_graph_invalid_with_invalid_connection(): def test_graph_invalid_with_invalid_connection():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.nodes[n1.id] = n1 g.nodes[n1.id] = n1
g.nodes[n2.id] = n2 g.nodes[n2.id] = n2
@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
n1.graph = Graph() n1.graph = Graph()
n1.graph.add_node n1.graph.add_node
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n1.graph.add_node(n1_1) n1.graph.add_node(n1_1)
g.add_node(n1) g.add_node(n1)
@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
def test_graph_gets_networkx_graph(): def test_graph_gets_networkx_graph():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph():
# TODO: Graph serializes and deserializes # TODO: Graph serializes and deserializes
def test_graph_can_serialize(): def test_graph_can_serialize():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)
@ -541,7 +540,7 @@ def test_graph_can_serialize():
def test_graph_can_deserialize(): def test_graph_can_deserialize():
g = Graph() g = Graph()
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2") n2 = UpscaleInvocation(id = "2")
g.add_node(n1) g.add_node(n1)
g.add_node(n2) g.add_node(n2)

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Literal from typing import Any, Callable, Literal, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.image import ImageField from invokeai.app.invocations.image import ImageField
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
image: ImageField = Field() image: ImageField = Field()
class ImageTestInvocation(BaseInvocation): class TextToImageTestInvocation(BaseInvocation):
type: Literal['test_image'] = 'test_image' type: Literal['test_text_to_image'] = 'test_text_to_image'
prompt: str = Field(default = "") prompt: str = Field(default = "")
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class ImageToImageTestInvocation(BaseInvocation):
type: Literal['test_image_to_image'] = 'test_image_to_image'
prompt: str = Field(default = "")
image: Union[ImageField, None] = Field(default=None)
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
collection: list[str] = Field(default_factory=list) collection: list[str] = Field(default_factory=list)
@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Edge, EdgeConnection from invokeai.app.services.graph import Edge, EdgeConnection

View File

@ -1,301 +0,0 @@
import unittest
from typing import Union
import torch
from invokeai.backend.stable_diffusion import TextualInversionManager
KNOWN_WORDS = ['a', 'b', 'c']
KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
UNKNOWN_WORDS = ['d', 'e', 'f']
class DummyEmbeddingsList(list):
def __getattr__(self, name):
if name == 'num_embeddings':
return len(self)
elif name == 'weight':
return self
elif name == 'data':
return self
def make_dummy_embedding():
return torch.randn([768])
class DummyTransformer:
def __init__(self):
self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))])
def resize_token_embeddings(self, new_size=None):
if new_size is None:
return self.embeddings
else:
while len(self.embeddings) > new_size:
self.embeddings.pop(-1)
while len(self.embeddings) < new_size:
self.embeddings.append(make_dummy_embedding())
def get_input_embeddings(self):
return self.embeddings
class DummyTokenizer():
def __init__(self):
self.tokens = KNOWN_WORDS.copy()
self.bos_token_id = 49406 # these are what the real CLIPTokenizer has
self.eos_token_id = 49407
self.pad_token_id = 49407
self.unk_token_id = 49407
def convert_tokens_to_ids(self, token_str):
try:
return self.tokens.index(token_str)
except ValueError:
return self.unk_token_id
def add_tokens(self, token_str):
if token_str in self.tokens:
return 0
self.tokens.append(token_str)
return 1
class DummyClipEmbedder:
def __init__(self):
self.max_length = 77
self.transformer = DummyTransformer()
self.tokenizer = DummyTokenizer()
self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
def position_embedding(self, indices: Union[list,torch.Tensor]):
if type(indices) is list:
indices = torch.tensor(indices, dtype=int)
return torch.index_select(self.position_embeddings_tensor, 0, indices)
def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
def make_dummy_textual_inversion_manager():
return TextualInversionManager(
tokenizer=DummyTokenizer(),
text_encoder=DummyTransformer()
)
class TextualInversionManagerTestCase(unittest.TestCase):
def test_construction(self):
tim = make_dummy_textual_inversion_manager()
def test_add_embedding_for_known_token(self):
tim = make_dummy_textual_inversion_manager()
test_embedding = torch.randn([1, 768])
test_embedding_name = KNOWN_WORDS[0]
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
ti = tim._add_textual_inversion(test_embedding_name, test_embedding)
self.assertEqual(ti.trigger_token_id, 0)
# check adding 'test' did not create a new word
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
self.assertEqual(pre_embeddings_count, embeddings_count)
# check it was added
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id)
def test_add_embedding_for_unknown_token(self):
tim = make_dummy_textual_inversion_manager()
test_embedding_1 = torch.randn([1, 768])
test_embedding_name_1 = UNKNOWN_WORDS[0]
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id
# new token id should get added on the end
self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
# check adding did create a new word
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
self.assertEqual(pre_embeddings_count+1, embeddings_count)
# check it was added
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
# add another one
test_embedding_2 = torch.randn([1, 768])
test_embedding_name_2 = UNKNOWN_WORDS[1]
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id
self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
# check adding did create a new word
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
self.assertEqual(pre_embeddings_count+1, embeddings_count)
# check it was added
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2)
# check the old one is still there
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
def test_pad_raises_on_eos_bos(self):
tim = make_dummy_textual_inversion_manager()
prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \
[KNOWN_WORDS_TOKEN_IDS] + \
[tim.tokenizer.eos_token_id]
with self.assertRaises(ValueError):
tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos)
def test_pad_tokens_list_vector_length_1(self):
tim = make_dummy_textual_inversion_manager()
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
test_embedding_1v = torch.randn([1, 768])
test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id
self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
# at the end
prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append)
self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
# at the start
prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend)
self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
# in the middle
prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert)
self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
def test_pad_tokens_list_vector_length_2(self):
tim = make_dummy_textual_inversion_manager()
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
test_embedding_2v = torch.randn([2, 768])
test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id
test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids
self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
# at the end
prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append)
self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids)
# at the start
prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend)
self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
# in the middle
prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert)
self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
def test_pad_tokens_list_vector_length_8(self):
tim = make_dummy_textual_inversion_manager()
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
test_embedding_8v = torch.randn([8, 768])
test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id
test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids
self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
# at the end
prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append)
self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids)
# at the start
prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend)
self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
# in the middle
prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert)
self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
def test_deferred_loading(self):
tim = make_dummy_textual_inversion_manager()
test_embedding = torch.randn([1, 768])
test_embedding_name = UNKNOWN_WORDS[0]
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True)
self.assertIsNone(ti.trigger_token_id)
# check that a new word is not yet created
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
self.assertEqual(pre_embeddings_count, embeddings_count)
# check it was added
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
self.assertIsNotNone(textual_inversion)
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id)
# check it lazy-loads
prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]])
tim.create_deferred_token_ids_for_any_trigger_terms(prompt)
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
self.assertEqual(pre_embeddings_count+1, embeddings_count)
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))