mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/fix-vae-conversion-crash
This commit is contained in:
commit
cfd09214d3
@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from invokeai.app.services.restoration_services import RestorationServices
|
from invokeai.app.services.restoration_services import RestorationServices
|
||||||
@ -22,46 +22,47 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
events: "EventServiceBase"
|
|
||||||
latents: "LatentsStorageBase"
|
|
||||||
queue: "InvocationQueueABC"
|
|
||||||
model_manager: "ModelManager"
|
|
||||||
restoration: "RestorationServices"
|
|
||||||
configuration: "InvokeAISettings"
|
|
||||||
images: "ImageServiceABC"
|
|
||||||
boards: "BoardServiceABC"
|
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
boards: "BoardServiceABC"
|
||||||
|
configuration: "InvokeAISettings"
|
||||||
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
|
images: "ImageServiceABC"
|
||||||
|
latents: "LatentsStorageBase"
|
||||||
|
logger: "Logger"
|
||||||
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
queue: "InvocationQueueABC"
|
||||||
|
restoration: "RestorationServices"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_manager: "ModelManager",
|
|
||||||
events: "EventServiceBase",
|
|
||||||
logger: "Logger",
|
|
||||||
latents: "LatentsStorageBase",
|
|
||||||
images: "ImageServiceABC",
|
|
||||||
boards: "BoardServiceABC",
|
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
queue: "InvocationQueueABC",
|
boards: "BoardServiceABC",
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
|
||||||
processor: "InvocationProcessorABC",
|
|
||||||
restoration: "RestorationServices",
|
|
||||||
configuration: "InvokeAISettings",
|
configuration: "InvokeAISettings",
|
||||||
|
events: "EventServiceBase",
|
||||||
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
|
images: "ImageServiceABC",
|
||||||
|
latents: "LatentsStorageBase",
|
||||||
|
logger: "Logger",
|
||||||
|
model_manager: "ModelManagerServiceBase",
|
||||||
|
processor: "InvocationProcessorABC",
|
||||||
|
queue: "InvocationQueueABC",
|
||||||
|
restoration: "RestorationServices",
|
||||||
):
|
):
|
||||||
self.model_manager = model_manager
|
|
||||||
self.events = events
|
|
||||||
self.logger = logger
|
|
||||||
self.latents = latents
|
|
||||||
self.images = images
|
|
||||||
self.boards = boards
|
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.queue = queue
|
|
||||||
self.graph_library = graph_library
|
|
||||||
self.graph_execution_manager = graph_execution_manager
|
|
||||||
self.processor = processor
|
|
||||||
self.restoration = restoration
|
|
||||||
self.configuration = configuration
|
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
self.boards = boards
|
||||||
|
self.configuration = configuration
|
||||||
|
self.events = events
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
self.graph_library = graph_library
|
||||||
|
self.images = images
|
||||||
|
self.latents = latents
|
||||||
|
self.logger = logger
|
||||||
|
self.model_manager = model_manager
|
||||||
|
self.processor = processor
|
||||||
|
self.queue = queue
|
||||||
|
self.restoration = restoration
|
||||||
|
@ -4,6 +4,8 @@ import argparse
|
|||||||
import shlex
|
import shlex
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
# note that this includes both old sampler names and new scheduler names
|
||||||
|
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
|
||||||
SAMPLER_CHOICES = [
|
SAMPLER_CHOICES = [
|
||||||
"ddim",
|
"ddim",
|
||||||
"ddpm",
|
"ddpm",
|
||||||
@ -27,6 +29,15 @@ SAMPLER_CHOICES = [
|
|||||||
"dpmpp_sde",
|
"dpmpp_sde",
|
||||||
"dpmpp_sde_k",
|
"dpmpp_sde_k",
|
||||||
"unipc",
|
"unipc",
|
||||||
|
"k_dpm_2_a",
|
||||||
|
"k_dpm_2",
|
||||||
|
"k_dpmpp_2_a",
|
||||||
|
"k_dpmpp_2",
|
||||||
|
"k_euler_a",
|
||||||
|
"k_euler",
|
||||||
|
"k_heun",
|
||||||
|
"k_lms",
|
||||||
|
"plms",
|
||||||
]
|
]
|
||||||
|
|
||||||
PRECISION_CHOICES = [
|
PRECISION_CHOICES = [
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
|
||||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
|
||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
|
||||||
from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
|
|
||||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
|
||||||
|
|
||||||
# Ignore these files as they need to be rewritten following the model manager refactor
|
|
||||||
collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def mock_services():
|
|
||||||
# NOTE: none of these are actually called by the test invocations
|
|
||||||
return InvocationServices(
|
|
||||||
model_manager = None, # type: ignore
|
|
||||||
events = None, # type: ignore
|
|
||||||
logger = None, # type: ignore
|
|
||||||
images = None, # type: ignore
|
|
||||||
latents = None, # type: ignore
|
|
||||||
board_images=None, # type: ignore
|
|
||||||
boards=None, # type: ignore
|
|
||||||
queue = MemoryInvocationQueue(),
|
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
|
||||||
filename=sqlite_memory, table_name="graphs"
|
|
||||||
),
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
|
||||||
processor = DefaultInvocationProcessor(),
|
|
||||||
restoration = None, # type: ignore
|
|
||||||
configuration = None, # type: ignore
|
|
||||||
)
|
|
@ -1,39 +1,79 @@
|
|||||||
import pytest
|
from .test_invoker import create_edge
|
||||||
|
from .test_nodes import (
|
||||||
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
TestEventService,
|
||||||
|
TextToImageTestInvocation,
|
||||||
|
PromptTestInvocation,
|
||||||
|
PromptCollectionTestInvocation,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||||
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationContext)
|
InvocationContext,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.collections import RangeInvocation
|
from invokeai.app.invocations.collections import RangeInvocation
|
||||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||||
from invokeai.app.services.graph import (CollectInvocation, Graph,
|
|
||||||
GraphExecutionState,
|
|
||||||
IterateInvocation)
|
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
from invokeai.app.services.graph import (
|
||||||
from .test_invoker import create_edge
|
Graph,
|
||||||
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
|
CollectInvocation,
|
||||||
PromptTestInvocation)
|
IterateInvocation,
|
||||||
|
GraphExecutionState,
|
||||||
|
LibraryGraph,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_graph():
|
def simple_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||||
g.add_node(ImageTestInvocation(id = "2"))
|
g.add_node(TextToImageTestInvocation(id="2"))
|
||||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
|
||||||
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
|
# the test invocations.
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_services() -> InvocationServices:
|
||||||
|
# NOTE: none of these are actually called by the test invocations
|
||||||
|
return InvocationServices(
|
||||||
|
model_manager = None, # type: ignore
|
||||||
|
events = TestEventService(),
|
||||||
|
logger = None, # type: ignore
|
||||||
|
images = None, # type: ignore
|
||||||
|
latents = None, # type: ignore
|
||||||
|
boards = None, # type: ignore
|
||||||
|
board_images = None, # type: ignore
|
||||||
|
queue = MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=sqlite_memory, table_name="graphs"
|
||||||
|
),
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None, # type: ignore
|
||||||
|
configuration = None, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_next(
|
||||||
|
g: GraphExecutionState, services: InvocationServices
|
||||||
|
) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||||
n = g.next()
|
n = g.next()
|
||||||
if n is None:
|
if n is None:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
print(f'invoking {n.id}: {type(n)}')
|
print(f"invoking {n.id}: {type(n)}")
|
||||||
o = n.invoke(InvocationContext(services, "1"))
|
o = n.invoke(InvocationContext(services, "1"))
|
||||||
g.complete(n.id, o)
|
g.complete(n.id, o)
|
||||||
|
|
||||||
return (n, o)
|
return (n, o)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_state_executes_in_order(simple_graph, mock_services):
|
def test_graph_state_executes_in_order(simple_graph, mock_services):
|
||||||
g = GraphExecutionState(graph=simple_graph)
|
g = GraphExecutionState(graph=simple_graph)
|
||||||
|
|
||||||
@ -47,6 +87,7 @@ def test_graph_state_executes_in_order(simple_graph, mock_services):
|
|||||||
assert g.results[n1[0].id].prompt == n1[0].prompt
|
assert g.results[n1[0].id].prompt == n1[0].prompt
|
||||||
assert n2[0].prompt == n1[0].prompt
|
assert n2[0].prompt == n1[0].prompt
|
||||||
|
|
||||||
|
|
||||||
def test_graph_is_complete(simple_graph, mock_services):
|
def test_graph_is_complete(simple_graph, mock_services):
|
||||||
g = GraphExecutionState(graph=simple_graph)
|
g = GraphExecutionState(graph=simple_graph)
|
||||||
n1 = invoke_next(g, mock_services)
|
n1 = invoke_next(g, mock_services)
|
||||||
@ -55,6 +96,7 @@ def test_graph_is_complete(simple_graph, mock_services):
|
|||||||
|
|
||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
|
|
||||||
def test_graph_is_not_complete(simple_graph, mock_services):
|
def test_graph_is_not_complete(simple_graph, mock_services):
|
||||||
g = GraphExecutionState(graph=simple_graph)
|
g = GraphExecutionState(graph=simple_graph)
|
||||||
n1 = invoke_next(g, mock_services)
|
n1 = invoke_next(g, mock_services)
|
||||||
@ -62,8 +104,10 @@ def test_graph_is_not_complete(simple_graph, mock_services):
|
|||||||
|
|
||||||
assert not g.is_complete()
|
assert not g.is_complete()
|
||||||
|
|
||||||
|
|
||||||
# TODO: test completion with iterators/subgraphs
|
# TODO: test completion with iterators/subgraphs
|
||||||
|
|
||||||
|
|
||||||
def test_graph_state_expands_iterator(mock_services):
|
def test_graph_state_expands_iterator(mock_services):
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
|
graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
|
||||||
@ -78,7 +122,7 @@ def test_graph_state_expands_iterator(mock_services):
|
|||||||
while not g.is_complete():
|
while not g.is_complete():
|
||||||
invoke_next(g, mock_services)
|
invoke_next(g, mock_services)
|
||||||
|
|
||||||
prepared_add_nodes = g.source_prepared_mapping['3']
|
prepared_add_nodes = g.source_prepared_mapping["3"]
|
||||||
results = set([g.results[n].a for n in prepared_add_nodes])
|
results = set([g.results[n].a for n in prepared_add_nodes])
|
||||||
expected = set([1, 11, 21])
|
expected = set([1, 11, 21])
|
||||||
assert results == expected
|
assert results == expected
|
||||||
@ -87,7 +131,9 @@ def test_graph_state_expands_iterator(mock_services):
|
|||||||
def test_graph_state_collects(mock_services):
|
def test_graph_state_collects(mock_services):
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
|
graph.add_node(
|
||||||
|
PromptCollectionTestInvocation(id="1", collection=list(test_prompts))
|
||||||
|
)
|
||||||
graph.add_node(IterateInvocation(id="2"))
|
graph.add_node(IterateInvocation(id="2"))
|
||||||
graph.add_node(PromptTestInvocation(id="3"))
|
graph.add_node(PromptTestInvocation(id="3"))
|
||||||
graph.add_node(CollectInvocation(id="4"))
|
graph.add_node(CollectInvocation(id="4"))
|
||||||
@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services):
|
|||||||
graph = Graph()
|
graph = Graph()
|
||||||
|
|
||||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
graph.add_node(
|
||||||
|
PromptCollectionTestInvocation(
|
||||||
|
id="prompt_collection", collection=list(test_prompts)
|
||||||
|
)
|
||||||
|
)
|
||||||
graph.add_node(IterateInvocation(id="iterate"))
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
graph.add_edge(
|
||||||
|
create_edge("prompt_collection", "collection", "iterate", "collection")
|
||||||
|
)
|
||||||
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
|
||||||
# separated, fully-preparable chain of nodes
|
# separated, fully-preparable chain of nodes
|
||||||
@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services):
|
|||||||
graph = Graph()
|
graph = Graph()
|
||||||
|
|
||||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
graph.add_node(
|
||||||
|
PromptCollectionTestInvocation(
|
||||||
|
id="prompt_collection", collection=list(test_prompts)
|
||||||
|
)
|
||||||
|
)
|
||||||
graph.add_node(IterateInvocation(id="iterate"))
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
||||||
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
graph.add_edge(
|
||||||
|
create_edge("prompt_collection", "collection", "iterate", "collection")
|
||||||
|
)
|
||||||
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
graph.add_edge(
|
||||||
|
create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")
|
||||||
|
)
|
||||||
|
|
||||||
g = GraphExecutionState(graph=graph)
|
g = GraphExecutionState(graph=graph)
|
||||||
n1 = invoke_next(g, mock_services)
|
n1 = invoke_next(g, mock_services)
|
||||||
|
@ -1,26 +1,62 @@
|
|||||||
import pytest
|
from .test_nodes import (
|
||||||
|
TestEventService,
|
||||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
ErrorInvocation,
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
TextToImageTestInvocation,
|
||||||
|
PromptTestInvocation,
|
||||||
|
create_edge,
|
||||||
|
wait_until,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||||
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from .test_nodes import (ErrorInvocation, ImageTestInvocation,
|
from invokeai.app.services.graph import (
|
||||||
PromptTestInvocation, create_edge, wait_until)
|
Graph,
|
||||||
|
GraphExecutionState,
|
||||||
|
LibraryGraph,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_graph():
|
def simple_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||||
g.add_node(ImageTestInvocation(id = "2"))
|
g.add_node(TextToImageTestInvocation(id="2"))
|
||||||
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||||
|
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||||
|
# the test invocations.
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_services() -> InvocationServices:
|
||||||
|
# NOTE: none of these are actually called by the test invocations
|
||||||
|
return InvocationServices(
|
||||||
|
model_manager = None, # type: ignore
|
||||||
|
events = TestEventService(),
|
||||||
|
logger = None, # type: ignore
|
||||||
|
images = None, # type: ignore
|
||||||
|
latents = None, # type: ignore
|
||||||
|
boards = None, # type: ignore
|
||||||
|
board_images = None, # type: ignore
|
||||||
|
queue = MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=sqlite_memory, table_name="graphs"
|
||||||
|
),
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None, # type: ignore
|
||||||
|
configuration = None, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||||
return Invoker(
|
return Invoker(services=mock_services)
|
||||||
services = mock_services
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||||
g = mock_invoker.create_execution_state()
|
g = mock_invoker.create_execution_state()
|
||||||
@ -29,6 +65,7 @@ def test_can_create_graph_state(mock_invoker: Invoker):
|
|||||||
assert g is not None
|
assert g is not None
|
||||||
assert isinstance(g, GraphExecutionState)
|
assert isinstance(g, GraphExecutionState)
|
||||||
|
|
||||||
|
|
||||||
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
||||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||||
mock_invoker.stop()
|
mock_invoker.stop()
|
||||||
@ -37,7 +74,8 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
|||||||
assert isinstance(g, GraphExecutionState)
|
assert isinstance(g, GraphExecutionState)
|
||||||
assert g.graph == simple_graph
|
assert g.graph == simple_graph
|
||||||
|
|
||||||
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
||||||
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||||
def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||||
invocation_id = mock_invoker.invoke(g)
|
invocation_id = mock_invoker.invoke(g)
|
||||||
@ -53,7 +91,8 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
|||||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||||
assert len(g.executed) > 0
|
assert len(g.executed) > 0
|
||||||
|
|
||||||
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
||||||
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||||
invocation_id = mock_invoker.invoke(g, invoke_all=True)
|
invocation_id = mock_invoker.invoke(g, invoke_all=True)
|
||||||
@ -69,7 +108,8 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
|||||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
|
||||||
|
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||||
def test_handles_errors(mock_invoker: Invoker):
|
def test_handles_errors(mock_invoker: Invoker):
|
||||||
g = mock_invoker.create_execution_state()
|
g = mock_invoker.create_execution_state()
|
||||||
g.graph.add_node(ErrorInvocation(id="1"))
|
g.graph.add_node(ErrorInvocation(id="1"))
|
||||||
@ -87,4 +127,4 @@ def test_handles_errors(mock_invoker: Invoker):
|
|||||||
assert g.has_error()
|
assert g.has_error()
|
||||||
assert g.is_complete()
|
assert g.is_complete()
|
||||||
|
|
||||||
assert all((i in g.errors for i in g.source_prepared_mapping['1']))
|
assert all((i in g.errors for i in g.source_prepared_mapping["1"]))
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||||
from invokeai.app.invocations.image import *
|
from invokeai.app.invocations.image import *
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
|||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
def test_connections_are_compatible():
|
def test_connections_are_compatible():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
@ -28,7 +27,7 @@ def test_connections_are_compatible():
|
|||||||
assert result == True
|
assert result == True
|
||||||
|
|
||||||
def test_connections_are_incompatible():
|
def test_connections_are_incompatible():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "image"
|
from_field = "image"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "strength"
|
to_field = "strength"
|
||||||
@ -38,7 +37,7 @@ def test_connections_are_incompatible():
|
|||||||
assert result == False
|
assert result == False
|
||||||
|
|
||||||
def test_connections_incompatible_with_invalid_fields():
|
def test_connections_incompatible_with_invalid_fields():
|
||||||
from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
from_field = "invalid_field"
|
from_field = "invalid_field"
|
||||||
to_node = UpscaleInvocation(id = "2")
|
to_node = UpscaleInvocation(id = "2")
|
||||||
to_field = "image"
|
to_field = "image"
|
||||||
@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields():
|
|||||||
|
|
||||||
def test_graph_can_add_node():
|
def test_graph_can_add_node():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
|
|
||||||
assert n.id in g.nodes
|
assert n.id in g.nodes
|
||||||
|
|
||||||
def test_graph_fails_to_add_node_with_duplicate_id():
|
def test_graph_fails_to_add_node_with_duplicate_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second")
|
||||||
|
|
||||||
with pytest.raises(NodeAlreadyInGraphError):
|
with pytest.raises(NodeAlreadyInGraphError):
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
def test_graph_updates_node():
|
def test_graph_updates_node():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated")
|
nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated")
|
||||||
|
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
@ -85,7 +84,7 @@ def test_graph_updates_node():
|
|||||||
|
|
||||||
def test_graph_fails_to_update_node_if_type_changes():
|
def test_graph_fails_to_update_node_if_type_changes():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes():
|
|||||||
|
|
||||||
def test_graph_allows_non_conflicting_id_change():
|
def test_graph_allows_non_conflicting_id_change():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
e1 = create_edge(n.id,"image",n2.id,"image")
|
e1 = create_edge(n.id,"image",n2.id,"image")
|
||||||
g.add_edge(e1)
|
g.add_edge(e1)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "3", prompt = "Banana sushi")
|
nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
with pytest.raises(NodeNotFoundError):
|
with pytest.raises(NodeNotFoundError):
|
||||||
@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change():
|
|||||||
|
|
||||||
def test_graph_fails_to_update_node_id_if_conflict():
|
def test_graph_fails_to_update_node_id_if_conflict():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.add_node(n)
|
g.add_node(n)
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second")
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
|
||||||
nu = TextToImageInvocation(id = "2", prompt = "Banana sushi")
|
nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
|
||||||
with pytest.raises(NodeAlreadyInGraphError):
|
with pytest.raises(NodeAlreadyInGraphError):
|
||||||
g.update_node("1", nu)
|
g.update_node("1", nu)
|
||||||
|
|
||||||
def test_graph_adds_edge():
|
def test_graph_adds_edge():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_long_cycle():
|
def test_graph_fails_to_add_edge_with_long_cycle():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = UpscaleInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_missing_node_id():
|
def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_when_destination_exists():
|
def test_graph_fails_to_add_edge_when_destination_exists():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
n3 = UpscaleInvocation(id = "3")
|
n3 = UpscaleInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
|
|||||||
|
|
||||||
def test_graph_fails_to_add_edge_with_mismatched_types():
|
def test_graph_fails_to_add_edge_with_mismatched_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types():
|
|||||||
|
|
||||||
def test_graph_connects_collector():
|
def test_graph_connects_collector():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2")
|
n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2")
|
||||||
n3 = CollectInvocation(id = "3")
|
n3 = CollectInvocation(id = "3")
|
||||||
n4 = ListPassThroughInvocation(id = "4")
|
n4 = ListPassThroughInvocation(id = "4")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -224,7 +223,7 @@ def test_graph_connects_collector():
|
|||||||
|
|
||||||
def test_graph_collector_invalid_with_varying_input_types():
|
def test_graph_collector_invalid_with_varying_input_types():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
|
n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2")
|
||||||
n3 = CollectInvocation(id = "3")
|
n3 = CollectInvocation(id = "3")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -282,7 +281,7 @@ def test_graph_connects_iterator():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = ListPassThroughInvocation(id = "1")
|
n1 = ListPassThroughInvocation(id = "1")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
|
n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
g.add_node(n3)
|
g.add_node(n3)
|
||||||
@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
|
|||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = ListPassThroughInvocation(id = "1")
|
n1 = ListPassThroughInvocation(id = "1")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi")
|
n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi")
|
||||||
n4 = ListPassThroughInvocation(id = "4")
|
n4 = ListPassThroughInvocation(id = "4")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs():
|
|||||||
|
|
||||||
def test_graph_iterator_invalid_if_input_not_list():
|
def test_graph_iterator_invalid_if_input_not_list():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = IterateInvocation(id = "2")
|
n2 = IterateInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
|
|||||||
|
|
||||||
def test_graph_validates():
|
def test_graph_validates():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -355,7 +354,7 @@ def test_graph_validates():
|
|||||||
|
|
||||||
def test_graph_invalid_if_edges_reference_missing_nodes():
|
def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
g.edges.append(e1)
|
g.edges.append(e1)
|
||||||
@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid():
|
|||||||
n1 = GraphInvocation(id = "1")
|
n1 = GraphInvocation(id = "1")
|
||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi")
|
||||||
n1.graph.nodes[n1_1.id] = n1_1
|
n1.graph.nodes[n1_1.id] = n1_1
|
||||||
e1 = create_edge("1","image","2","image")
|
e1 = create_edge("1","image","2","image")
|
||||||
n1.graph.edges.append(e1)
|
n1.graph.edges.append(e1)
|
||||||
@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle():
|
|||||||
|
|
||||||
def test_graph_invalid_with_invalid_connection():
|
def test_graph_invalid_with_invalid_connection():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.nodes[n1.id] = n1
|
g.nodes[n1.id] = n1
|
||||||
g.nodes[n2.id] = n2
|
g.nodes[n2.id] = n2
|
||||||
@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
n1.graph = Graph()
|
n1.graph = Graph()
|
||||||
n1.graph.add_node
|
n1.graph.add_node
|
||||||
|
|
||||||
n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n1.graph.add_node(n1_1)
|
n1.graph.add_node(n1_1)
|
||||||
|
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
|||||||
|
|
||||||
def test_graph_gets_networkx_graph():
|
def test_graph_gets_networkx_graph():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph():
|
|||||||
# TODO: Graph serializes and deserializes
|
# TODO: Graph serializes and deserializes
|
||||||
def test_graph_can_serialize():
|
def test_graph_can_serialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
@ -541,7 +540,7 @@ def test_graph_can_serialize():
|
|||||||
|
|
||||||
def test_graph_can_deserialize():
|
def test_graph_can_deserialize():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi")
|
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||||
n2 = UpscaleInvocation(id = "2")
|
n2 = UpscaleInvocation(id = "2")
|
||||||
g.add_node(n1)
|
g.add_node(n1)
|
||||||
g.add_node(n2)
|
g.add_node(n2)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, Literal
|
from typing import Any, Callable, Literal, Union
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from invokeai.app.invocations.image import ImageField
|
from invokeai.app.invocations.image import ImageField
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
image: ImageField = Field()
|
image: ImageField = Field()
|
||||||
|
|
||||||
class ImageTestInvocation(BaseInvocation):
|
class TextToImageTestInvocation(BaseInvocation):
|
||||||
type: Literal['test_image'] = 'test_image'
|
type: Literal['test_text_to_image'] = 'test_text_to_image'
|
||||||
|
|
||||||
prompt: str = Field(default = "")
|
prompt: str = Field(default = "")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
|
class ImageToImageTestInvocation(BaseInvocation):
|
||||||
|
type: Literal['test_image_to_image'] = 'test_image_to_image'
|
||||||
|
|
||||||
|
prompt: str = Field(default = "")
|
||||||
|
image: Union[ImageField, None] = Field(default=None)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
|
|
||||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
|
type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output'
|
||||||
collection: list[str] = Field(default_factory=list)
|
collection: list[str] = Field(default_factory=list)
|
||||||
@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import Edge, EdgeConnection
|
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||||
|
|
||||||
|
@ -1,301 +0,0 @@
|
|||||||
|
|
||||||
import unittest
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion import TextualInversionManager
|
|
||||||
|
|
||||||
|
|
||||||
KNOWN_WORDS = ['a', 'b', 'c']
|
|
||||||
KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
|
|
||||||
UNKNOWN_WORDS = ['d', 'e', 'f']
|
|
||||||
|
|
||||||
class DummyEmbeddingsList(list):
|
|
||||||
def __getattr__(self, name):
|
|
||||||
if name == 'num_embeddings':
|
|
||||||
return len(self)
|
|
||||||
elif name == 'weight':
|
|
||||||
return self
|
|
||||||
elif name == 'data':
|
|
||||||
return self
|
|
||||||
|
|
||||||
def make_dummy_embedding():
|
|
||||||
return torch.randn([768])
|
|
||||||
|
|
||||||
class DummyTransformer:
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))])
|
|
||||||
|
|
||||||
def resize_token_embeddings(self, new_size=None):
|
|
||||||
if new_size is None:
|
|
||||||
return self.embeddings
|
|
||||||
else:
|
|
||||||
while len(self.embeddings) > new_size:
|
|
||||||
self.embeddings.pop(-1)
|
|
||||||
while len(self.embeddings) < new_size:
|
|
||||||
self.embeddings.append(make_dummy_embedding())
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.embeddings
|
|
||||||
|
|
||||||
class DummyTokenizer():
|
|
||||||
def __init__(self):
|
|
||||||
self.tokens = KNOWN_WORDS.copy()
|
|
||||||
self.bos_token_id = 49406 # these are what the real CLIPTokenizer has
|
|
||||||
self.eos_token_id = 49407
|
|
||||||
self.pad_token_id = 49407
|
|
||||||
self.unk_token_id = 49407
|
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, token_str):
|
|
||||||
try:
|
|
||||||
return self.tokens.index(token_str)
|
|
||||||
except ValueError:
|
|
||||||
return self.unk_token_id
|
|
||||||
|
|
||||||
def add_tokens(self, token_str):
|
|
||||||
if token_str in self.tokens:
|
|
||||||
return 0
|
|
||||||
self.tokens.append(token_str)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
class DummyClipEmbedder:
|
|
||||||
def __init__(self):
|
|
||||||
self.max_length = 77
|
|
||||||
self.transformer = DummyTransformer()
|
|
||||||
self.tokenizer = DummyTokenizer()
|
|
||||||
self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
|
|
||||||
|
|
||||||
def position_embedding(self, indices: Union[list,torch.Tensor]):
|
|
||||||
if type(indices) is list:
|
|
||||||
indices = torch.tensor(indices, dtype=int)
|
|
||||||
return torch.index_select(self.position_embeddings_tensor, 0, indices)
|
|
||||||
|
|
||||||
|
|
||||||
def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
|
|
||||||
return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
|
|
||||||
|
|
||||||
|
|
||||||
def make_dummy_textual_inversion_manager():
|
|
||||||
return TextualInversionManager(
|
|
||||||
tokenizer=DummyTokenizer(),
|
|
||||||
text_encoder=DummyTransformer()
|
|
||||||
)
|
|
||||||
|
|
||||||
class TextualInversionManagerTestCase(unittest.TestCase):
|
|
||||||
|
|
||||||
|
|
||||||
def test_construction(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
|
|
||||||
def test_add_embedding_for_known_token(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
test_embedding = torch.randn([1, 768])
|
|
||||||
test_embedding_name = KNOWN_WORDS[0]
|
|
||||||
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
|
||||||
|
|
||||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
|
|
||||||
ti = tim._add_textual_inversion(test_embedding_name, test_embedding)
|
|
||||||
self.assertEqual(ti.trigger_token_id, 0)
|
|
||||||
|
|
||||||
|
|
||||||
# check adding 'test' did not create a new word
|
|
||||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
self.assertEqual(pre_embeddings_count, embeddings_count)
|
|
||||||
|
|
||||||
# check it was added
|
|
||||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
|
||||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
|
||||||
self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id)
|
|
||||||
|
|
||||||
def test_add_embedding_for_unknown_token(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
test_embedding_1 = torch.randn([1, 768])
|
|
||||||
test_embedding_name_1 = UNKNOWN_WORDS[0]
|
|
||||||
|
|
||||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
|
|
||||||
added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id
|
|
||||||
# new token id should get added on the end
|
|
||||||
self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
|
|
||||||
|
|
||||||
# check adding did create a new word
|
|
||||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
|
||||||
|
|
||||||
# check it was added
|
|
||||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
|
||||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
|
||||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
|
||||||
|
|
||||||
# add another one
|
|
||||||
test_embedding_2 = torch.randn([1, 768])
|
|
||||||
test_embedding_name_2 = UNKNOWN_WORDS[1]
|
|
||||||
|
|
||||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
|
|
||||||
added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id
|
|
||||||
self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
|
|
||||||
|
|
||||||
# check adding did create a new word
|
|
||||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
|
||||||
|
|
||||||
# check it was added
|
|
||||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
|
|
||||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
|
|
||||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2)
|
|
||||||
|
|
||||||
# check the old one is still there
|
|
||||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
|
||||||
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
|
||||||
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pad_raises_on_eos_bos(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \
|
|
||||||
[KNOWN_WORDS_TOKEN_IDS] + \
|
|
||||||
[tim.tokenizer.eos_token_id]
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos)
|
|
||||||
|
|
||||||
def test_pad_tokens_list_vector_length_1(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
|
||||||
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
test_embedding_1v = torch.randn([1, 768])
|
|
||||||
test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
|
|
||||||
test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id
|
|
||||||
self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
|
|
||||||
|
|
||||||
# at the end
|
|
||||||
prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append)
|
|
||||||
self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# at the start
|
|
||||||
prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend)
|
|
||||||
self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# in the middle
|
|
||||||
prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert)
|
|
||||||
self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
def test_pad_tokens_list_vector_length_2(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
|
||||||
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
test_embedding_2v = torch.randn([2, 768])
|
|
||||||
test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
|
|
||||||
test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id
|
|
||||||
test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids
|
|
||||||
self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
|
|
||||||
|
|
||||||
# at the end
|
|
||||||
prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append)
|
|
||||||
self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# at the start
|
|
||||||
prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend)
|
|
||||||
self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# in the middle
|
|
||||||
prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert)
|
|
||||||
self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
def test_pad_tokens_list_vector_length_8(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
|
||||||
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
test_embedding_8v = torch.randn([8, 768])
|
|
||||||
test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
|
|
||||||
test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id
|
|
||||||
test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids
|
|
||||||
self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
|
|
||||||
|
|
||||||
# at the end
|
|
||||||
prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append)
|
|
||||||
self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# at the start
|
|
||||||
prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend)
|
|
||||||
self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
# in the middle
|
|
||||||
prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
|
|
||||||
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert)
|
|
||||||
self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
|
|
||||||
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deferred_loading(self):
|
|
||||||
tim = make_dummy_textual_inversion_manager()
|
|
||||||
test_embedding = torch.randn([1, 768])
|
|
||||||
test_embedding_name = UNKNOWN_WORDS[0]
|
|
||||||
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
|
||||||
|
|
||||||
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
|
|
||||||
ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True)
|
|
||||||
self.assertIsNone(ti.trigger_token_id)
|
|
||||||
|
|
||||||
# check that a new word is not yet created
|
|
||||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
self.assertEqual(pre_embeddings_count, embeddings_count)
|
|
||||||
|
|
||||||
# check it was added
|
|
||||||
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
|
||||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
|
||||||
self.assertIsNotNone(textual_inversion)
|
|
||||||
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
|
||||||
self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id)
|
|
||||||
|
|
||||||
# check it lazy-loads
|
|
||||||
prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]])
|
|
||||||
tim.create_deferred_token_ids_for_any_trigger_terms(prompt)
|
|
||||||
|
|
||||||
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
|
||||||
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
|
||||||
|
|
||||||
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
|
||||||
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
|
||||||
self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))
|
|
Loading…
Reference in New Issue
Block a user