From c00aea7a6ce7e257c2c7324777dc0b15c01cedfd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 29 Jun 2023 16:01:17 +1000 Subject: [PATCH 01/51] tests(nodes): fix nodes tests --- invokeai/app/services/invocation_services.py | 67 +++-- tests/conftest.py | 30 -- tests/nodes/test_graph_execution_state.py | 130 +++++--- tests/nodes/test_invoker.py | 88 ++++-- tests/nodes/test_node_graph.py | 75 +++-- tests/nodes/test_nodes.py | 16 +- tests/test_textual_inversion.py | 301 ------------------- 7 files changed, 242 insertions(+), 465 deletions(-) delete mode 100644 tests/conftest.py delete mode 100644 tests/test_textual_inversion.py diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 10d1d91920..4e1da3b040 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from invokeai.app.services.board_images import BoardImagesServiceABC from invokeai.app.services.boards import BoardServiceABC from invokeai.app.services.images import ImageServiceABC - from invokeai.backend import ModelManager + from invokeai.app.services.model_manager_service import ModelManagerServiceBase from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.restoration_services import RestorationServices @@ -22,46 +22,47 @@ class InvocationServices: """Services that can be used by invocations""" # TODO: Just forward-declared everything due to circular dependencies. Fix structure. - events: "EventServiceBase" - latents: "LatentsStorageBase" - queue: "InvocationQueueABC" - model_manager: "ModelManager" - restoration: "RestorationServices" - configuration: "InvokeAISettings" - images: "ImageServiceABC" - boards: "BoardServiceABC" board_images: "BoardImagesServiceABC" - graph_library: "ItemStorageABC"["LibraryGraph"] + boards: "BoardServiceABC" + configuration: "InvokeAISettings" + events: "EventServiceBase" graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] + graph_library: "ItemStorageABC"["LibraryGraph"] + images: "ImageServiceABC" + latents: "LatentsStorageBase" + logger: "Logger" + model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + queue: "InvocationQueueABC" + restoration: "RestorationServices" def __init__( self, - model_manager: "ModelManager", - events: "EventServiceBase", - logger: "Logger", - latents: "LatentsStorageBase", - images: "ImageServiceABC", - boards: "BoardServiceABC", board_images: "BoardImagesServiceABC", - queue: "InvocationQueueABC", - graph_library: "ItemStorageABC"["LibraryGraph"], - graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], - processor: "InvocationProcessorABC", - restoration: "RestorationServices", + boards: "BoardServiceABC", configuration: "InvokeAISettings", + events: "EventServiceBase", + graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], + graph_library: "ItemStorageABC"["LibraryGraph"], + images: "ImageServiceABC", + latents: "LatentsStorageBase", + logger: "Logger", + model_manager: "ModelManagerServiceBase", + processor: "InvocationProcessorABC", + queue: "InvocationQueueABC", + restoration: "RestorationServices", ): - self.model_manager = model_manager - self.events = events - self.logger = logger - self.latents = latents - self.images = images - self.boards = boards self.board_images = board_images - self.queue = queue - self.graph_library = graph_library - self.graph_execution_manager = graph_execution_manager - self.processor = processor - self.restoration = restoration - self.configuration = configuration self.boards = boards + self.boards = boards + self.configuration = configuration + self.events = events + self.graph_execution_manager = graph_execution_manager + self.graph_library = graph_library + self.images = images + self.latents = latents + self.logger = logger + self.model_manager = model_manager + self.processor = processor + self.queue = queue + self.restoration = restoration diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 06502b6c41..0000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_queue import MemoryInvocationQueue -from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory -from invokeai.app.services.graph import LibraryGraph, GraphExecutionState -from invokeai.app.services.processor import DefaultInvocationProcessor - -# Ignore these files as they need to be rewritten following the model manager refactor -collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"] - -@pytest.fixture(scope="session", autouse=True) -def mock_services(): - # NOTE: none of these are actually called by the test invocations - return InvocationServices( - model_manager = None, # type: ignore - events = None, # type: ignore - logger = None, # type: ignore - images = None, # type: ignore - latents = None, # type: ignore - board_images=None, # type: ignore - boards=None, # type: ignore - queue = MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=sqlite_memory, table_name="graphs" - ), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor(), - restoration = None, # type: ignore - configuration = None, # type: ignore - ) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index df8964da18..f34b18310b 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,41 +1,81 @@ -import pytest - -from invokeai.app.invocations.baseinvocation import (BaseInvocation, - BaseInvocationOutput, - InvocationContext) +from .test_invoker import create_edge +from .test_nodes import ( + TestEventService, + TextToImageTestInvocation, + PromptTestInvocation, + PromptCollectionTestInvocation, +) +from invokeai.app.services.invocation_queue import MemoryInvocationQueue +from invokeai.app.services.processor import DefaultInvocationProcessor +from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, +) from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.services.graph import (CollectInvocation, Graph, - GraphExecutionState, - IterateInvocation) from invokeai.app.services.invocation_services import InvocationServices - -from .test_invoker import create_edge -from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation, - PromptTestInvocation) +from invokeai.app.services.graph import ( + Graph, + CollectInvocation, + IterateInvocation, + GraphExecutionState, + LibraryGraph, +) +import pytest @pytest.fixture def simple_graph(): g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) + g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) + g.add_node(TextToImageTestInvocation(id="2")) g.add_edge(create_edge("1", "prompt", "2", "prompt")) return g -def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: + +# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types +# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate +# the test invocations. +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + model_manager = None, # type: ignore + events = TestEventService(), + logger = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore + boards = None, # type: ignore + board_images = None, # type: ignore + queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor(), + restoration = None, # type: ignore + configuration = None, # type: ignore + ) + + +def invoke_next( + g: GraphExecutionState, services: InvocationServices +) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() if n is None: return (None, None) - print(f'invoking {n.id}: {type(n)}') + print(f"invoking {n.id}: {type(n)}") o = n.invoke(InvocationContext(services, "1")) g.complete(n.id, o) return (n, o) + def test_graph_state_executes_in_order(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) @@ -47,38 +87,42 @@ def test_graph_state_executes_in_order(simple_graph, mock_services): assert g.results[n1[0].id].prompt == n1[0].prompt assert n2[0].prompt == n1[0].prompt + def test_graph_is_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) n3 = g.next() assert g.is_complete() + def test_graph_is_not_complete(simple_graph, mock_services): - g = GraphExecutionState(graph = simple_graph) + g = GraphExecutionState(graph=simple_graph) n1 = invoke_next(g, mock_services) n2 = g.next() assert not g.is_complete() + # TODO: test completion with iterators/subgraphs + def test_graph_state_expands_iterator(mock_services): graph = Graph() - graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1)) - graph.add_node(IterateInvocation(id = "1")) - graph.add_node(MultiplyInvocation(id = "2", b = 10)) - graph.add_node(AddInvocation(id = "3", b = 1)) + graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1)) + graph.add_node(IterateInvocation(id="1")) + graph.add_node(MultiplyInvocation(id="2", b=10)) + graph.add_node(AddInvocation(id="3", b=1)) graph.add_edge(create_edge("0", "collection", "1", "collection")) graph.add_edge(create_edge("1", "item", "2", "a")) graph.add_edge(create_edge("2", "a", "3", "a")) - g = GraphExecutionState(graph = graph) + g = GraphExecutionState(graph=graph) while not g.is_complete(): invoke_next(g, mock_services) - prepared_add_nodes = g.source_prepared_mapping['3'] + prepared_add_nodes = g.source_prepared_mapping["3"] results = set([g.results[n].a for n in prepared_add_nodes]) expected = set([1, 11, 21]) assert results == expected @@ -87,15 +131,17 @@ def test_graph_state_expands_iterator(mock_services): def test_graph_state_collects(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(PromptTestInvocation(id = "3")) - graph.add_node(CollectInvocation(id = "4")) + graph.add_node( + PromptCollectionTestInvocation(id="1", collection=list(test_prompts)) + ) + graph.add_node(IterateInvocation(id="2")) + graph.add_node(PromptTestInvocation(id="3")) + graph.add_node(CollectInvocation(id="4")) graph.add_edge(create_edge("1", "collection", "2", "collection")) graph.add_edge(create_edge("2", "item", "3", "prompt")) graph.add_edge(create_edge("3", "prompt", "4", "item")) - g = GraphExecutionState(graph = graph) + g = GraphExecutionState(graph=graph) n1 = invoke_next(g, mock_services) n2 = invoke_next(g, mock_services) n3 = invoke_next(g, mock_services) @@ -113,10 +159,16 @@ def test_graph_state_prepares_eagerly(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node( + PromptCollectionTestInvocation( + id="prompt_collection", collection=list(test_prompts) + ) + ) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) - graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge( + create_edge("prompt_collection", "collection", "iterate", "collection") + ) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) # separated, fully-preparable chain of nodes @@ -142,13 +194,21 @@ def test_graph_executes_depth_first(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node( + PromptCollectionTestInvocation( + id="prompt_collection", collection=list(test_prompts) + ) + ) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) graph.add_node(PromptTestInvocation(id="prompt_successor")) - graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge( + create_edge("prompt_collection", "collection", "iterate", "collection") + ) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) - graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) + graph.add_edge( + create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt") + ) g = GraphExecutionState(graph=graph) n1 = invoke_next(g, mock_services) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 4331e62d21..19d7dd20b3 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -1,26 +1,62 @@ -import pytest - -from invokeai.app.services.graph import Graph, GraphExecutionState -from invokeai.app.services.invocation_services import InvocationServices +from .test_nodes import ( + TestEventService, + ErrorInvocation, + TextToImageTestInvocation, + PromptTestInvocation, + create_edge, + wait_until, +) +from invokeai.app.services.invocation_queue import MemoryInvocationQueue +from invokeai.app.services.processor import DefaultInvocationProcessor +from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker - -from .test_nodes import (ErrorInvocation, ImageTestInvocation, - PromptTestInvocation, create_edge, wait_until) +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.graph import ( + Graph, + GraphExecutionState, + LibraryGraph, +) +import pytest @pytest.fixture def simple_graph(): g = Graph() - g.add_node(PromptTestInvocation(id = "1", prompt = "Banana sushi")) - g.add_node(ImageTestInvocation(id = "2")) + g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) + g.add_node(TextToImageTestInvocation(id="2")) g.add_edge(create_edge("1", "prompt", "2", "prompt")) return g + +# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types +# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate +# the test invocations. +@pytest.fixture +def mock_services() -> InvocationServices: + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + model_manager = None, # type: ignore + events = TestEventService(), + logger = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore + boards = None, # type: ignore + board_images = None, # type: ignore + queue = MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph]( + filename=sqlite_memory, table_name="graphs" + ), + graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), + processor = DefaultInvocationProcessor(), + restoration = None, # type: ignore + configuration = None, # type: ignore + ) + + @pytest.fixture() def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker( - services = mock_services - ) + return Invoker(services=mock_services) + def test_can_create_graph_state(mock_invoker: Invoker): g = mock_invoker.create_execution_state() @@ -29,17 +65,19 @@ def test_can_create_graph_state(mock_invoker: Invoker): assert g is not None assert isinstance(g, GraphExecutionState) + def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) + g = mock_invoker.create_execution_state(graph=simple_graph) mock_invoker.stop() assert g is not None assert isinstance(g, GraphExecutionState) assert g.graph == simple_graph -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) + g = mock_invoker.create_execution_state(graph=simple_graph) invocation_id = mock_invoker.invoke(g) assert invocation_id is not None @@ -47,32 +85,34 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): g = mock_invoker.services.graph_execution_manager.get(g.id) return len(g.executed) > 0 - wait_until(lambda: has_executed_any(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_any(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert len(g.executed) > 0 -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke_all(mock_invoker: Invoker, simple_graph): - g = mock_invoker.create_execution_state(graph = simple_graph) - invocation_id = mock_invoker.invoke(g, invoke_all = True) + g = mock_invoker.create_execution_state(graph=simple_graph) + invocation_id = mock_invoker.invoke(g, invoke_all=True) assert invocation_id is not None def has_executed_all(g: GraphExecutionState): g = mock_invoker.services.graph_execution_manager.get(g.id) return g.is_complete() - wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_all(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.is_complete() -@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") + +# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.create_execution_state() - g.graph.add_node(ErrorInvocation(id = "1")) + g.graph.add_node(ErrorInvocation(id="1")) mock_invoker.invoke(g, invoke_all=True) @@ -80,11 +120,11 @@ def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.services.graph_execution_manager.get(g.id) return g.is_complete() - wait_until(lambda: has_executed_all(g), timeout = 5, interval = 1) + wait_until(lambda: has_executed_all(g), timeout=5, interval=1) mock_invoker.stop() g = mock_invoker.services.graph_execution_manager.get(g.id) assert g.has_error() assert g.is_complete() - assert all((i in g.errors for i in g.source_prepared_mapping['1'])) + assert all((i in g.errors for i in g.source_prepared_mapping["1"])) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index 82818414b2..df7378150d 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,6 +1,5 @@ -from .test_nodes import ListPassThroughInvocation, PromptTestInvocation +from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation -from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.image import * from invokeai.app.invocations.math import AddInvocation, SubtractInvocation @@ -18,7 +17,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg # Tests def test_connections_are_compatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" to_node = UpscaleInvocation(id = "2") to_field = "image" @@ -28,7 +27,7 @@ def test_connections_are_compatible(): assert result == True def test_connections_are_incompatible(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "image" to_node = UpscaleInvocation(id = "2") to_field = "strength" @@ -38,7 +37,7 @@ def test_connections_are_incompatible(): assert result == False def test_connections_incompatible_with_invalid_fields(): - from_node = TextToImageInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") from_field = "invalid_field" to_node = UpscaleInvocation(id = "2") to_field = "image" @@ -56,28 +55,28 @@ def test_connections_incompatible_with_invalid_fields(): def test_graph_can_add_node(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) assert n.id in g.nodes def test_graph_fails_to_add_node_with_duplicate_id(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "1", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second") with pytest.raises(NodeAlreadyInGraphError): g.add_node(n2) def test_graph_updates_node(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") g.add_node(n2) - nu = TextToImageInvocation(id = "1", prompt = "Banana sushi updated") + nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated") g.update_node("1", nu) @@ -85,7 +84,7 @@ def test_graph_updates_node(): def test_graph_fails_to_update_node_if_type_changes(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) n2 = UpscaleInvocation(id = "2") g.add_node(n2) @@ -97,14 +96,14 @@ def test_graph_fails_to_update_node_if_type_changes(): def test_graph_allows_non_conflicting_id_change(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) n2 = UpscaleInvocation(id = "2") g.add_node(n2) e1 = create_edge(n.id,"image",n2.id,"image") g.add_edge(e1) - nu = TextToImageInvocation(id = "3", prompt = "Banana sushi") + nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi") g.update_node("1", nu) with pytest.raises(NodeNotFoundError): @@ -117,18 +116,18 @@ def test_graph_allows_non_conflicting_id_change(): def test_graph_fails_to_update_node_id_if_conflict(): g = Graph() - n = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.add_node(n) - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") g.add_node(n2) - nu = TextToImageInvocation(id = "2", prompt = "Banana sushi") + nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") with pytest.raises(NodeAlreadyInGraphError): g.update_node("1", nu) def test_graph_adds_edge(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -148,7 +147,7 @@ def test_graph_fails_to_add_edge_with_cycle(): def test_graph_fails_to_add_edge_with_long_cycle(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") n3 = UpscaleInvocation(id = "3") g.add_node(n1) @@ -164,7 +163,7 @@ def test_graph_fails_to_add_edge_with_long_cycle(): def test_graph_fails_to_add_edge_with_missing_node_id(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -177,7 +176,7 @@ def test_graph_fails_to_add_edge_with_missing_node_id(): def test_graph_fails_to_add_edge_when_destination_exists(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") n3 = UpscaleInvocation(id = "3") g.add_node(n1) @@ -194,7 +193,7 @@ def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_with_mismatched_types(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -204,8 +203,8 @@ def test_graph_fails_to_add_edge_with_mismatched_types(): def test_graph_connects_collector(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") - n2 = TextToImageInvocation(id = "2", prompt = "Banana sushi 2") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2") n3 = CollectInvocation(id = "3") n4 = ListPassThroughInvocation(id = "4") g.add_node(n1) @@ -224,7 +223,7 @@ def test_graph_connects_collector(): def test_graph_collector_invalid_with_varying_input_types(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") n3 = CollectInvocation(id = "3") g.add_node(n1) @@ -282,7 +281,7 @@ def test_graph_connects_iterator(): g = Graph() n1 = ListPassThroughInvocation(id = "1") n2 = IterateInvocation(id = "2") - n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") g.add_node(n1) g.add_node(n2) g.add_node(n3) @@ -298,7 +297,7 @@ def test_graph_iterator_invalid_if_multiple_inputs(): g = Graph() n1 = ListPassThroughInvocation(id = "1") n2 = IterateInvocation(id = "2") - n3 = ImageToImageInvocation(id = "3", prompt = "Banana sushi") + n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") n4 = ListPassThroughInvocation(id = "4") g.add_node(n1) g.add_node(n2) @@ -316,7 +315,7 @@ def test_graph_iterator_invalid_if_multiple_inputs(): def test_graph_iterator_invalid_if_input_not_list(): g = Graph() - n1 = TextToImageInvocation(id = "1", promopt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = IterateInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -344,7 +343,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different(): def test_graph_validates(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -355,7 +354,7 @@ def test_graph_validates(): def test_graph_invalid_if_edges_reference_missing_nodes(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") g.nodes[n1.id] = n1 e1 = create_edge("1","image","2","image") g.edges.append(e1) @@ -367,7 +366,7 @@ def test_graph_invalid_if_subgraph_invalid(): n1 = GraphInvocation(id = "1") n1.graph = Graph() - n1_1 = TextToImageInvocation(id = "2", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") n1.graph.nodes[n1_1.id] = n1_1 e1 = create_edge("1","image","2","image") n1.graph.edges.append(e1) @@ -391,7 +390,7 @@ def test_graph_invalid_if_has_cycle(): def test_graph_invalid_with_invalid_connection(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 @@ -408,7 +407,7 @@ def test_graph_gets_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -485,7 +484,7 @@ def test_graph_fails_to_get_missing_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -499,7 +498,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) @@ -512,7 +511,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node(): def test_graph_gets_networkx_graph(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -529,7 +528,7 @@ def test_graph_gets_networkx_graph(): # TODO: Graph serializes and deserializes def test_graph_can_serialize(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) @@ -541,7 +540,7 @@ def test_graph_can_serialize(): def test_graph_can_deserialize(): g = Graph() - n1 = TextToImageInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") n2 = UpscaleInvocation(id = "2") g.add_node(n1) g.add_node(n2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index d16d67d815..af011954c5 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.image import ImageField from invokeai.app.services.invocation_services import InvocationServices @@ -43,14 +43,23 @@ class ImageTestInvocationOutput(BaseInvocationOutput): image: ImageField = Field() -class ImageTestInvocation(BaseInvocation): - type: Literal['test_image'] = 'test_image' +class TextToImageTestInvocation(BaseInvocation): + type: Literal['test_text_to_image'] = 'test_text_to_image' prompt: str = Field(default = "") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) +class ImageToImageTestInvocation(BaseInvocation): + type: Literal['test_image_to_image'] = 'test_image_to_image' + + prompt: str = Field(default = "") + image: Union[ImageField, None] = Field(default=None) + + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: + return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) + class PromptCollectionTestInvocationOutput(BaseInvocationOutput): type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' collection: list[str] = Field(default_factory=list) @@ -62,7 +71,6 @@ class PromptCollectionTestInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) - from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Edge, EdgeConnection diff --git a/tests/test_textual_inversion.py b/tests/test_textual_inversion.py deleted file mode 100644 index 53d2f2bfe6..0000000000 --- a/tests/test_textual_inversion.py +++ /dev/null @@ -1,301 +0,0 @@ - -import unittest -from typing import Union - -import torch - -from invokeai.backend.stable_diffusion import TextualInversionManager - - -KNOWN_WORDS = ['a', 'b', 'c'] -KNOWN_WORDS_TOKEN_IDS = [0, 1, 2] -UNKNOWN_WORDS = ['d', 'e', 'f'] - -class DummyEmbeddingsList(list): - def __getattr__(self, name): - if name == 'num_embeddings': - return len(self) - elif name == 'weight': - return self - elif name == 'data': - return self - -def make_dummy_embedding(): - return torch.randn([768]) - -class DummyTransformer: - - - def __init__(self): - self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))]) - - def resize_token_embeddings(self, new_size=None): - if new_size is None: - return self.embeddings - else: - while len(self.embeddings) > new_size: - self.embeddings.pop(-1) - while len(self.embeddings) < new_size: - self.embeddings.append(make_dummy_embedding()) - - def get_input_embeddings(self): - return self.embeddings - -class DummyTokenizer(): - def __init__(self): - self.tokens = KNOWN_WORDS.copy() - self.bos_token_id = 49406 # these are what the real CLIPTokenizer has - self.eos_token_id = 49407 - self.pad_token_id = 49407 - self.unk_token_id = 49407 - - def convert_tokens_to_ids(self, token_str): - try: - return self.tokens.index(token_str) - except ValueError: - return self.unk_token_id - - def add_tokens(self, token_str): - if token_str in self.tokens: - return 0 - self.tokens.append(token_str) - return 1 - - -class DummyClipEmbedder: - def __init__(self): - self.max_length = 77 - self.transformer = DummyTransformer() - self.tokenizer = DummyTokenizer() - self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32) - - def position_embedding(self, indices: Union[list,torch.Tensor]): - if type(indices) is list: - indices = torch.tensor(indices, dtype=int) - return torch.index_select(self.position_embeddings_tensor, 0, indices) - - -def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool: - return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices)) - - -def make_dummy_textual_inversion_manager(): - return TextualInversionManager( - tokenizer=DummyTokenizer(), - text_encoder=DummyTransformer() - ) - -class TextualInversionManagerTestCase(unittest.TestCase): - - - def test_construction(self): - tim = make_dummy_textual_inversion_manager() - - def test_add_embedding_for_known_token(self): - tim = make_dummy_textual_inversion_manager() - test_embedding = torch.randn([1, 768]) - test_embedding_name = KNOWN_WORDS[0] - self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - ti = tim._add_textual_inversion(test_embedding_name, test_embedding) - self.assertEqual(ti.trigger_token_id, 0) - - - # check adding 'test' did not create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id) - - def test_add_embedding_for_unknown_token(self): - tim = make_dummy_textual_inversion_manager() - test_embedding_1 = torch.randn([1, 768]) - test_embedding_name_1 = UNKNOWN_WORDS[0] - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id - # new token id should get added on the end - self.assertEqual(added_token_id_1, len(KNOWN_WORDS)) - - # check adding did create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1) - - # add another one - test_embedding_2 = torch.randn([1, 768]) - test_embedding_name_2 = UNKNOWN_WORDS[1] - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id - self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1) - - # check adding did create a new word - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2) - - # check the old one is still there - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1)) - textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1) - self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1) - - - def test_pad_raises_on_eos_bos(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \ - [KNOWN_WORDS_TOKEN_IDS] + \ - [tim.tokenizer.eos_token_id] - with self.assertRaises(ValueError): - tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos) - - def test_pad_tokens_list_vector_length_1(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_1v = torch.randn([1, 768]) - test_embedding_1v_token = "" - test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id - self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append) - self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend) - self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert) - self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids) - - def test_pad_tokens_list_vector_length_2(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_2v = torch.randn([2, 768]) - test_embedding_2v_token = "" - test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id - test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids - self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append) - self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend) - self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids) - self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert) - self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids) - - def test_pad_tokens_list_vector_length_8(self): - tim = make_dummy_textual_inversion_manager() - prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy() - - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids) - self.assertEqual(prompt_token_ids, expanded_prompt_token_ids) - - test_embedding_8v = torch.randn([8, 768]) - test_embedding_8v_token = "" - test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id - test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids - self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS)) - - # at the end - prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append) - self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids) - - # at the start - prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend) - self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids) - self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids) - - # in the middle - prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3] - expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert) - self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids) - self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids) - - - def test_deferred_loading(self): - tim = make_dummy_textual_inversion_manager() - test_embedding = torch.randn([1, 768]) - test_embedding_name = UNKNOWN_WORDS[0] - self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - - pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - - ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True) - self.assertIsNone(ti.trigger_token_id) - - # check that a new word is not yet created - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count, embeddings_count) - - # check it was added - self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name)) - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertIsNotNone(textual_inversion) - self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding)) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id) - - # check it lazy-loads - prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]]) - tim.create_deferred_token_ids_for_any_trigger_terms(prompt) - - embeddings_count = len(tim.text_encoder.resize_token_embeddings(None)) - self.assertEqual(pre_embeddings_count+1, embeddings_count) - - textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name) - self.assertEqual(textual_inversion.trigger_string, test_embedding_name) - self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS)) From 4d2c7806fcfee98e677ba0f46c39f64ca7e5d5c3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Jul 2023 10:08:10 -0400 Subject: [PATCH 02/51] quash memory leak when compel invocation called --- invokeai/app/invocations/compel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 8c6b23944c..0421841e8a 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -2,6 +2,7 @@ from typing import Literal, Optional, Union from pydantic import BaseModel, Field from contextlib import ExitStack import re +import torch from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .model import ClipField @@ -56,6 +57,7 @@ class CompelInvocation(BaseInvocation): }, } + @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: tokenizer_info = context.services.model_manager.get_model( From e73f774920d891f73141678790d9c37fc27762ed Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 26 Jun 2023 18:14:44 +1200 Subject: [PATCH 03/51] fix: Restore Model display and select functionality --- .../components/ModelManager/ModelList.tsx | 130 ++++++++---------- .../components/ModelManager/ModelListItem.tsx | 17 ++- .../ModelManager/ModelManagerModal.tsx | 12 +- 3 files changed, 75 insertions(+), 84 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx index 4ef311e1d4..b090c2a07b 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelList.tsx @@ -1,36 +1,16 @@ import { Box, Flex, Heading, Spacer, Spinner, Text } from '@chakra-ui/react'; -import IAIInput from 'common/components/IAIInput'; import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; import AddModel from './AddModel'; -import ModelListItem from './ModelListItem'; import MergeModels from './MergeModels'; +import ModelListItem from './ModelListItem'; -import { useAppSelector } from 'app/store/storeHooks'; import { useTranslation } from 'react-i18next'; -import { createSelector } from '@reduxjs/toolkit'; -import { systemSelector } from 'features/system/store/systemSelectors'; -import type { SystemState } from 'features/system/store/systemSlice'; -import { isEqual, map } from 'lodash-es'; - -import React, { useMemo, useState, useTransition } from 'react'; import type { ChangeEvent, ReactNode } from 'react'; - -const modelListSelector = createSelector( - systemSelector, - (system: SystemState) => { - const models = map(system.model_list, (model, key) => { - return { name: key, ...model }; - }); - return models; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import React, { useMemo, useState, useTransition } from 'react'; +import { useListModelsQuery } from 'services/api/endpoints/models'; function ModelFilterButton({ label, @@ -58,7 +38,9 @@ function ModelFilterButton({ } const ModelList = () => { - const models = useAppSelector(modelListSelector); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); const [renderModelList, setRenderModelList] = React.useState(false); @@ -90,43 +72,49 @@ const ModelList = () => { const filteredModelListItemsToRender: ReactNode[] = []; const localFilteredModelListItemsToRender: ReactNode[] = []; - models.forEach((model, i) => { - if (model.name.toLowerCase().includes(searchText.toLowerCase())) { + if (!pipelineModels) return; + + const modelList = pipelineModels.entities; + + Object.keys(modelList).forEach((model, i) => { + if ( + modelList[model].name.toLowerCase().includes(searchText.toLowerCase()) + ) { filteredModelListItemsToRender.push( ); - if (model.format === isSelectedFilter) { + if (modelList[model]?.model_format === isSelectedFilter) { localFilteredModelListItemsToRender.push( ); } } - if (model.format !== 'diffusers') { + if (modelList[model]?.model_format !== 'diffusers') { ckptModelListItemsToRender.push( ); } else { diffusersModelListItemsToRender.push( ); } @@ -142,6 +130,23 @@ const ModelList = () => { {isSelectedFilter === 'all' && ( <> + + + {t('modelManager.diffusersModels')} + + {diffusersModelListItemsToRender} + { {ckptModelListItemsToRender} - - - {t('modelManager.diffusersModels')} - - {diffusersModelListItemsToRender} - )} - {isSelectedFilter === 'ckpt' && ( - - {ckptModelListItemsToRender} - - )} - {isSelectedFilter === 'diffusers' && ( {diffusersModelListItemsToRender} )} + + {isSelectedFilter === 'ckpt' && ( + + {ckptModelListItemsToRender} + + )} ); - }, [models, searchText, t, isSelectedFilter]); + }, [pipelineModels, searchText, t, isSelectedFilter]); return ( @@ -211,7 +199,7 @@ const ModelList = () => { { onClick={() => setIsSelectedFilter('all')} isActive={isSelectedFilter === 'all'} /> - setIsSelectedFilter('ckpt')} - isActive={isSelectedFilter === 'ckpt'} - /> setIsSelectedFilter('diffusers')} isActive={isSelectedFilter === 'diffusers'} /> + setIsSelectedFilter('ckpt')} + isActive={isSelectedFilter === 'ckpt'} + /> {renderModelList ? ( diff --git a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx index aa9f87816c..e1b3bbab1e 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelManager/ModelListItem.tsx @@ -1,6 +1,6 @@ import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { Box, Button, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; -import { ModelStatus } from 'app/types/invokeai'; + // import { deleteModel, requestModelChange } from 'app/socketio/actions'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -10,9 +10,9 @@ import { setOpenModel } from 'features/system/store/systemSlice'; import { useTranslation } from 'react-i18next'; type ModelListItemProps = { + modelKey: string; name: string; - status: ModelStatus; - description: string; + description: string | undefined; }; export default function ModelListItem(props: ModelListItemProps) { @@ -28,18 +28,18 @@ export default function ModelListItem(props: ModelListItemProps) { const dispatch = useAppDispatch(); - const { name, status, description } = props; + const { modelKey, name, description } = props; const handleChangeModel = () => { - dispatch(requestModelChange(name)); + dispatch(requestModelChange(modelKey)); }; const openModelHandler = () => { - dispatch(setOpenModel(name)); + dispatch(setOpenModel(modelKey)); }; const handleModelDelete = () => { - dispatch(deleteModel(name)); + dispatch(deleteModel(modelKey)); dispatch(setOpenModel(null)); }; @@ -60,7 +60,7 @@ export default function ModelListItem(props: ModelListItemProps) { p={2} borderRadius="base" sx={ - name === openModel + modelKey === openModel ? { bg: 'accent.750', _hover: { @@ -81,7 +81,6 @@ export default function ModelListItem(props: ModelListItemProps) { - {status} - } size="sm" From 2f8f558df3e68bc14bd8f7d2e6654bd97b2e016c Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 30 Jun 2023 09:10:19 +1200 Subject: [PATCH 17/51] wip: Move Add Model from Modal to new Panel --- .../ModelManager/subpanels/AddModelsPanel.tsx | 53 +++++++- .../AddModelsPanel/AddDiffusersModel.tsx | 2 +- .../subpanels/AddModelsPanel/AddModel.tsx | 125 ------------------ 3 files changed, 50 insertions(+), 130 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModel.tsx diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx index b1c03fcb6b..25f4adf4aa 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel.tsx @@ -1,10 +1,55 @@ -import { Flex } from '@chakra-ui/react'; -import AddModel from 'features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddModel'; +import { Divider, Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; +import { useTranslation } from 'react-i18next'; +import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel'; +import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel'; export default function AddModelsPanel() { + const addNewModelUIOption = useAppSelector( + (state: RootState) => state.ui.addNewModelUIOption + ); + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + return ( - - + + + dispatch(setAddNewModelUIOption('ckpt'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addCheckpointModel')} + + dispatch(setAddNewModelUIOption('diffusers'))} + sx={{ + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700', + '&:hover': { + backgroundColor: + addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600', + }, + }} + > + {t('modelManager.addDiffuserModel')} + + + + + + {addNewModelUIOption == 'ckpt' && } + {addNewModelUIOption == 'diffusers' && } ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx index cb3af5f176..dd491828da 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AddDiffusersModel.tsx @@ -66,7 +66,7 @@ export default function AddDiffusersModel() { }; return ( - + void; -}) { - return ( - - {text} - - ); -} - -export default function AddModel() { - const { isOpen, onOpen, onClose } = useDisclosure(); - - const addNewModelUIOption = useAppSelector( - (state: RootState) => state.ui.addNewModelUIOption - ); - - const dispatch = useAppDispatch(); - - const { t } = useTranslation(); - - const addModelModalClose = () => { - onClose(); - dispatch(setAddNewModelUIOption(null)); - }; - - return ( - <> - - - - {t('modelManager.addNew')} - - - - - - - {t('modelManager.addNewModel')} - {addNewModelUIOption !== null && ( - dispatch(setAddNewModelUIOption(null))} - position="absolute" - variant="ghost" - zIndex={1} - size="sm" - insetInlineEnd={12} - top={2} - icon={} - /> - )} - - - {addNewModelUIOption == null && ( - - dispatch(setAddNewModelUIOption('ckpt'))} - /> - dispatch(setAddNewModelUIOption('diffusers'))} - /> - - )} - {addNewModelUIOption == 'ckpt' && } - {addNewModelUIOption == 'diffusers' && } - - - - - - ); -} From 6684e00f0a9464ca95b598bb8c209d14566e0528 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 30 Jun 2023 09:31:51 +1200 Subject: [PATCH 18/51] wip: Move Merge Models to new panel & use new model fetch --- .../subpanels/MergeModelsPanel.tsx | 258 ++++++++++++++- .../MergeModelsPanel/MergeModels.tsx | 313 ------------------ 2 files changed, 254 insertions(+), 317 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels.tsx diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 8c26357720..0cd90a9492 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -1,10 +1,260 @@ -import { Flex } from '@chakra-ui/react'; -import MergeModels from 'features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels'; +import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; +import IAISelect from 'common/components/IAISelect'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; +import IAISlider from 'common/components/IAISlider'; +import { pickBy } from 'lodash-es'; +import { useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useListModelsQuery } from 'services/api/endpoints/models'; export default function MergeModelsPanel() { + const { t } = useTranslation(); + + const dispatch = useAppDispatch(); + + const { data } = useListModelsQuery({ + model_type: 'main', + }); + + const diffusersModels = pickBy( + data?.entities, + (value, _) => value?.model_format === 'diffusers' + ); + + const [modelOne, setModelOne] = useState( + Object.keys(diffusersModels)[0] + ); + const [modelTwo, setModelTwo] = useState( + Object.keys(diffusersModels)[1] + ); + const [modelThree, setModelThree] = useState('none'); + + const [mergedModelName, setMergedModelName] = useState(''); + const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); + + const [modelMergeInterp, setModelMergeInterp] = useState< + 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' + >('weighted_sum'); + + const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< + 'root' | 'custom' + >('root'); + + const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = + useState(''); + + const [modelMergeForce, setModelMergeForce] = useState(false); + + const modelOneList = Object.keys(diffusersModels).filter( + (model) => model !== modelTwo && model !== modelThree + ); + + const modelTwoList = Object.keys(diffusersModels).filter( + (model) => model !== modelOne && model !== modelThree + ); + + const modelThreeList = [ + { key: t('modelManager.none'), value: 'none' }, + ...Object.keys(diffusersModels) + .filter((model) => model !== modelOne && model !== modelTwo) + .map((model) => ({ key: model, value: model })), + ]; + + const isProcessing = useAppSelector( + (state: RootState) => state.system.isProcessing + ); + + const mergeModelsHandler = () => { + let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; + modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); + + const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { + models_to_merge: modelsToMerge, + merged_model_name: + mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), + alpha: modelMergeAlpha, + interp: modelMergeInterp, + model_merge_save_path: + modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, + force: modelMergeForce, + }; + + dispatch(mergeDiffusersModels(mergeModelsInfo)); + }; + return ( - - + + + {t('modelManager.modelMergeHeaderHelp1')} + + {t('modelManager.modelMergeHeaderHelp2')} + + + + setModelOne(e.target.value)} + /> + setModelTwo(e.target.value)} + /> + { + if (e.target.value !== 'none') { + setModelThree(e.target.value); + setModelMergeInterp('add_difference'); + } else { + setModelThree('none'); + setModelMergeInterp('weighted_sum'); + } + }} + /> + + + setMergedModelName(e.target.value)} + /> + + + setModelMergeAlpha(v)} + withInput + withReset + handleReset={() => setModelMergeAlpha(0.5)} + withSliderMarks + /> + + {t('modelManager.modelMergeAlphaHelp')} + + + + + + {t('modelManager.interpolationType')} + + setModelMergeInterp(v)} + > + + {modelThree === 'none' ? ( + <> + + {t('modelManager.weightedSum')} + + + {t('modelManager.sigmoid')} + + + {t('modelManager.inverseSigmoid')} + + + ) : ( + + + {t('modelManager.addDifference')} + + + )} + + + + + + + + {t('modelManager.mergedModelSaveLocation')} + + setModelMergeSaveLocType(v)} + > + + + {t('modelManager.invokeAIFolder')} + + + + {t('modelManager.custom')} + + + + + + {modelMergeSaveLocType === 'custom' && ( + setModelMergeCustomSaveLoc(e.target.value)} + /> + )} + + + setModelMergeForce(e.target.checked)} + fontWeight="500" + /> + + + {t('modelManager.merge')} + ); } diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels.tsx deleted file mode 100644 index 219d49d4ee..0000000000 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel/MergeModels.tsx +++ /dev/null @@ -1,313 +0,0 @@ -import { - Flex, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Radio, - RadioGroup, - Text, - Tooltip, - useDisclosure, -} from '@chakra-ui/react'; -// import { mergeDiffusersModels } from 'app/socketio/actions'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIButton from 'common/components/IAIButton'; -import IAIInput from 'common/components/IAIInput'; -import IAISelect from 'common/components/IAISelect'; -import { diffusersModelsSelector } from 'features/system/store/systemSelectors'; -import { useState } from 'react'; -import { useTranslation } from 'react-i18next'; -import * as InvokeAI from 'app/types/invokeai'; -import IAISlider from 'common/components/IAISlider'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; - -export default function MergeModels() { - const dispatch = useAppDispatch(); - - const { isOpen, onOpen, onClose } = useDisclosure(); - - const diffusersModels = useAppSelector(diffusersModelsSelector); - - const { t } = useTranslation(); - - const [modelOne, setModelOne] = useState( - Object.keys(diffusersModels)[0] - ); - const [modelTwo, setModelTwo] = useState( - Object.keys(diffusersModels)[1] - ); - const [modelThree, setModelThree] = useState('none'); - - const [mergedModelName, setMergedModelName] = useState(''); - const [modelMergeAlpha, setModelMergeAlpha] = useState(0.5); - - const [modelMergeInterp, setModelMergeInterp] = useState< - 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' - >('weighted_sum'); - - const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< - 'root' | 'custom' - >('root'); - - const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = - useState(''); - - const [modelMergeForce, setModelMergeForce] = useState(false); - - const modelOneList = Object.keys(diffusersModels).filter( - (model) => model !== modelTwo && model !== modelThree - ); - - const modelTwoList = Object.keys(diffusersModels).filter( - (model) => model !== modelOne && model !== modelThree - ); - - const modelThreeList = [ - { key: t('modelManager.none'), value: 'none' }, - ...Object.keys(diffusersModels) - .filter((model) => model !== modelOne && model !== modelTwo) - .map((model) => ({ key: model, value: model })), - ]; - - const isProcessing = useAppSelector( - (state: RootState) => state.system.isProcessing - ); - - const mergeModelsHandler = () => { - let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; - modelsToMerge = modelsToMerge.filter((model) => model !== 'none'); - - const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { - models_to_merge: modelsToMerge, - merged_model_name: - mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), - alpha: modelMergeAlpha, - interp: modelMergeInterp, - model_merge_save_path: - modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, - force: modelMergeForce, - }; - - dispatch(mergeDiffusersModels(mergeModelsInfo)); - }; - - return ( - <> - - - {t('modelManager.mergeModels')} - - - - - - - {t('modelManager.mergeModels')} - - - - - {t('modelManager.modelMergeHeaderHelp1')} - - {t('modelManager.modelMergeHeaderHelp2')} - - - - setModelOne(e.target.value)} - /> - setModelTwo(e.target.value)} - /> - { - if (e.target.value !== 'none') { - setModelThree(e.target.value); - setModelMergeInterp('add_difference'); - } else { - setModelThree('none'); - setModelMergeInterp('weighted_sum'); - } - }} - /> - - - setMergedModelName(e.target.value)} - /> - - - setModelMergeAlpha(v)} - withInput - withReset - handleReset={() => setModelMergeAlpha(0.5)} - withSliderMarks - /> - - {t('modelManager.modelMergeAlphaHelp')} - - - - - - {t('modelManager.interpolationType')} - - setModelMergeInterp(v)} - > - - {modelThree === 'none' ? ( - <> - - - {t('modelManager.weightedSum')} - - - - {t('modelManager.sigmoid')} - - - - {t('modelManager.inverseSigmoid')} - - - - ) : ( - - - - {t('modelManager.addDifference')} - - - - )} - - - - - - - - {t('modelManager.mergedModelSaveLocation')} - - - setModelMergeSaveLocType(v) - } - > - - - - {t('modelManager.invokeAIFolder')} - - - - - {t('modelManager.custom')} - - - - - - {modelMergeSaveLocType === 'custom' && ( - setModelMergeCustomSaveLoc(e.target.value)} - /> - )} - - - setModelMergeForce(e.target.checked)} - fontWeight="500" - /> - - - {t('modelManager.merge')} - - - - - - - - ); -} From 6c6299ce490b8a53b4a255cc1b5bfb106e566768 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Fri, 30 Jun 2023 09:51:48 +1200 Subject: [PATCH 19/51] fix: Style fixes to the MM edit forms --- .../subpanels/ModelManagerPanel/CheckpointModelEdit.tsx | 8 +------- .../subpanels/ModelManagerPanel/DiffusersModelEdit.tsx | 7 +------ .../subpanels/ModelManagerPanel/ModelConvert.tsx | 1 - 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx index 34a6d6885e..0d5d21175a 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/CheckpointModelEdit.tsx @@ -81,19 +81,13 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) { flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll" - paddingInlineEnd={8} >
editModelFormSubmitHandler(values) )} > - + - + 🧨 {t('modelManager.convertToDiffusers')} From 630f3c8b0b8dd745a5308ba8f4465c0b30a77fd0 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 10:06:18 +1200 Subject: [PATCH 20/51] fix: Missing VAE Loader stuff --- .../frontend/web/src/features/nodes/types/constants.ts | 7 +++++++ .../web/src/features/nodes/util/modelIdToVAEModelField.ts | 2 +- invokeai/frontend/web/src/services/api/types.d.ts | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 9f6124c9d4..b864501803 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record = { ClipField: 'clip', VaeField: 'vae', model: 'model', + vae_model: 'vae_model', array: 'array', item: 'item', ColorField: 'color', @@ -116,6 +117,12 @@ export const FIELDS: Record = { title: 'Model', description: 'Models are models.', }, + vae_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'Model', + description: 'Models are models.', + }, array: { color: 'gray', colorCssVar: getColorTokenCssVariable('gray'), diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts index 3850ad443d..0cb608a936 100644 --- a/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToVAEModelField.ts @@ -1,4 +1,4 @@ -import { BaseModelType } from 'services/api/types'; +import { BaseModelType, VAEModelField } from 'services/api/types'; /** * Crudely converts a model id to a main model field diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 3da844d764..18942a47d6 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -34,6 +34,7 @@ export type OffsetPaginatedResults_ImageDTO_ = export type ModelType = S<'ModelType'>; export type BaseModelType = S<'BaseModelType'>; export type MainModelField = S<'MainModelField'>; +export type VAEModelField = S<'VAEModelField'>; export type ModelsList = S<'ModelsList'>; // Graphs From fa8a5838d3487cdda10a4077488e08dbfddaa7f5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 30 Jun 2023 18:15:04 -0400 Subject: [PATCH 21/51] add vae lodaer --- invokeai/app/invocations/model.py | 54 +++++- .../frontend/web/src/services/api/schema.d.ts | 174 ++++++++++++------ 2 files changed, 169 insertions(+), 59 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 12b8c7cdd6..73a640e04b 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,10 +1,9 @@ import copy -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional from pydantic import BaseModel, Field from ...backend.model_management import BaseModelType, ModelType, SubModelType -from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) @@ -32,7 +31,6 @@ class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") - class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" @@ -223,3 +221,53 @@ class LoraLoaderInvocation(BaseInvocation): return output +class VAEModelField(BaseModel): + """Vae model field""" + + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") + +class VaeLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + #fmt: off + type: Literal["vae_loader_output"] = "vae_loader_output" + + vae: VaeField = Field(default=None, description="Vae model") + #fmt: on + +class VaeLoaderInvocation(BaseInvocation): + """Loads a VAE model, outputting a VaeLoaderOutput""" + type: Literal["vae_loader"] = "vae_loader" + + vae_model: VAEModelField = Field(description="The VAE to load") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["model", "loader"], + "type_hints": { + "vae_model": "vae_model" + } + }, + } + + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: + base_model = self.vae.base_model + model_name = self.vae.model_name + model_type = ModelType.vae + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=model_name, + model_type=model_type, + ): + raise Exception(f"Unkown vae name: {model_name}!") + return VaeLoaderOutput( + vae=VaeField( + model_name = model_name, + base_model = base_model, + model_type = model_type, + ) + ) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 9d8ac90535..41058e3cd3 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1068,12 +1068,6 @@ export type components = { nodes?: { [key: string]: | ( - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -1087,31 +1081,38 @@ export type components = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation'] @@ -1177,20 +1178,21 @@ export type components = { results: { [key: string]: | ( - | components['schemas']['IntCollectionOutput'] - | components['schemas']['FloatCollectionOutput'] - | components['schemas']['ModelLoaderOutput'] - | components['schemas']['LoraLoaderOutput'] - | components['schemas']['CompelOutput'] | components['schemas']['ImageOutput'] | components['schemas']['MaskOutput'] + | components['schemas']['ModelLoaderOutput'] + | components['schemas']['LoraLoaderOutput'] + | components['schemas']['VaeLoaderOutput'] + | components['schemas']['CompelOutput'] | components['schemas']['ControlOutput'] | components['schemas']['LatentsOutput'] | components['schemas']['IntOutput'] | components['schemas']['FloatOutput'] - | components['schemas']['NoiseOutput'] + | components['schemas']['IntCollectionOutput'] + | components['schemas']['FloatCollectionOutput'] | components['schemas']['PromptOutput'] | components['schemas']['PromptCollectionOutput'] + | components['schemas']['NoiseOutput'] | components['schemas']['GraphInvocationOutput'] | components['schemas']['IterateInvocationOutput'] | components['schemas']['CollectInvocationOutput'] @@ -3267,14 +3269,14 @@ export type components = { ModelsList: { /** Models */ models: ( - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] | components['schemas']['VaeModelConfig'] | components['schemas']['LoRAModelConfig'] | components['schemas']['ControlNetModelConfig'] | components['schemas']['TextualInversionModelConfig'] - | components['schemas']['StableDiffusion2ModelDiffusersConfig'] | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig'] )[]; }; /** @@ -4539,6 +4541,19 @@ export type components = { */ level?: 2 | 4; }; + /** + * VAEModelField + * @description Vae model field + */ + VAEModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components['schemas']['BaseModelType']; + }; /** VaeField */ VaeField: { /** @@ -4547,6 +4562,51 @@ export type components = { */ vae: components['schemas']['ModelInfo']; }; + /** + * VaeLoaderInvocation + * @description Loads a VAE model, outputting a VaeLoaderOutput + */ + VaeLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default vae_loader + * @enum {string} + */ + type?: 'vae_loader'; + /** + * Vae Model + * @description The VAE to load + */ + vae_model: components['schemas']['VAEModelField']; + }; + /** + * VaeLoaderOutput + * @description Model loader output + */ + VaeLoaderOutput: { + /** + * Type + * @default vae_loader_output + * @enum {string} + */ + type?: 'vae_loader_output'; + /** + * Vae + * @description Vae model + */ + vae?: components['schemas']['VaeField']; + }; /** VaeModelConfig */ VaeModelConfig: { /** Name */ @@ -4625,18 +4685,18 @@ export type components = { */ image?: components['schemas']['ImageField']; }; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers'; /** * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion2ModelFormat: 'checkpoint' | 'diffusers'; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers'; }; responses: never; parameters: never; @@ -4747,12 +4807,6 @@ export type operations = { requestBody: { content: { 'application/json': - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -4766,31 +4820,38 @@ export type operations = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation'] @@ -4847,12 +4908,6 @@ export type operations = { requestBody: { content: { 'application/json': - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -4866,31 +4921,38 @@ export type operations = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation'] From a8a22095603cb92833243b3065cd051eade1dd22 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 30 Jun 2023 18:46:36 -0400 Subject: [PATCH 22/51] VAE loader is loading proper VAE. Unclear if it is changing the image --- invokeai/app/invocations/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 73a640e04b..447d488803 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -254,9 +254,9 @@ class VaeLoaderInvocation(BaseInvocation): } def invoke(self, context: InvocationContext) -> VaeLoaderOutput: - base_model = self.vae.base_model - model_name = self.vae.model_name - model_type = ModelType.vae + base_model = self.vae_model.base_model + model_name = self.vae_model.model_name + model_type = ModelType.Vae if not context.services.model_manager.model_exists( base_model=base_model, @@ -266,8 +266,10 @@ class VaeLoaderInvocation(BaseInvocation): raise Exception(f"Unkown vae name: {model_name}!") return VaeLoaderOutput( vae=VaeField( - model_name = model_name, - base_model = base_model, - model_type = model_type, + vae = ModelInfo( + model_name = model_name, + base_model = base_model, + model_type = model_type, + ) ) ) From bd5a7649883203894c8d03b56a19dc5c001b293b Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 11:05:03 +1200 Subject: [PATCH 23/51] Remove 'automatic' from VAE Loader in Nodes --- .../components/fields/VaeModelInputFieldComponent.tsx | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx index 8a341d920b..74d9942c84 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx @@ -38,13 +38,7 @@ const VaeModelInputFieldComponent = ( return []; } - const data: SelectItem[] = [ - { - value: 'auto', - label: 'Automatic', - group: 'Default', - }, - ]; + const data: SelectItem[] = []; forEach(vaeModels.entities, (model, id) => { if (!model) { From 7e18814dd020482119f813aecfaabd1a720ec9f1 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 11:10:46 +1200 Subject: [PATCH 24/51] Add standard names for Model Loader Nodes --- invokeai/app/invocations/model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 447d488803..e51873c59e 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -62,6 +62,7 @@ class MainModelLoaderInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { "ui": { + "title": "Model Loader", "tags": ["model", "loader"], "type_hints": { "model": "model" @@ -175,6 +176,14 @@ class LoraLoaderInvocation(BaseInvocation): unet: Optional[UNetField] = Field(description="UNet model for applying lora") clip: Optional[ClipField] = Field(description="Clip model for applying lora") + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "Lora Loader", + "tags": ["lora", "loader"], + }, + } + def invoke(self, context: InvocationContext) -> LoraLoaderOutput: # TODO: ui rewrite @@ -246,7 +255,8 @@ class VaeLoaderInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { "ui": { - "tags": ["model", "loader"], + "title": "VAE Loader", + "tags": ["vae", "loader"], "type_hints": { "vae_model": "vae_model" } From 511978979e7f25299b2e8eda4c10fb83dae616ba Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 12:10:35 +1200 Subject: [PATCH 25/51] feat: Add Custom VAE Support to Linear UI --- .../nodes/util/graphBuilders/addVAEToGraph.ts | 68 +++++++++++++++++++ .../buildCanvasImageToImageGraph.ts | 24 ++----- .../graphBuilders/buildCanvasInpaintGraph.ts | 14 ++-- .../buildCanvasTextToImageGraph.ts | 14 ++-- .../buildLinearImageToImageGraph.ts | 28 ++------ .../buildLinearTextToImageGraph.ts | 14 ++-- .../nodes/util/graphBuilders/constants.ts | 1 + 7 files changed, 91 insertions(+), 72 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts new file mode 100644 index 0000000000..a4a10e574e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -0,0 +1,68 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; +import { + IMAGE_TO_LATENTS, + INPAINT, + LATENTS_TO_IMAGE, + MAIN_MODEL_LOADER, + VAE_LOADER, +} from './constants'; + +export const addVAEToGraph = ( + graph: NonNullableGraph, + state: RootState +): void => { + const { vae: vaeId } = state.generation; + const vae_model = modelIdToVAEModelField(vaeId); + + if (vaeId !== 'auto') { + graph.nodes[VAE_LOADER] = { + type: 'vae_loader', + id: VAE_LOADER, + vae_model, + }; + } + + if ( + graph.id === 'text_to_image_graph' || + graph.id === 'image_to_image_graph' + ) { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }); + } + + if (graph.id === 'image_to_image_graph') { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'vae', + }, + }); + } + + if (graph.id === 'inpaint_graph') { + graph.edges.push({ + source: { + node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, + field: 'vae', + }, + destination: { + node_id: INPAINT, + field: 'vae', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 26713fdbb6..5cf9882ac1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -9,6 +9,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_LATENTS, @@ -122,16 +123,6 @@ export const buildCanvasImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -162,16 +153,6 @@ export const buildCanvasImageToImageGraph = ( field: 'noise', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, { source: { node_id: MAIN_MODEL_LOADER, @@ -271,6 +252,9 @@ export const buildCanvasImageToImageGraph = ( }); } + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 0a8dd67477..82912de219 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -8,6 +8,7 @@ import { RangeOfSizeInvocation, } from 'services/api/types'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addVAEToGraph } from './addVAEToGraph'; import { INPAINT, INPAINT_GRAPH, @@ -170,16 +171,6 @@ export const buildCanvasInpaintGraph = ( field: 'unet', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: INPAINT, - field: 'vae', - }, - }, { source: { node_id: RANGE_OF_SIZE, @@ -203,6 +194,9 @@ export const buildCanvasInpaintGraph = ( ], }; + // Add VAE + addVAEToGraph(graph, state); + // handle seed if (shouldRandomizeSeed) { // Random int node to generate the starting seed diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index cfc74564f1..cfe5e62805 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, @@ -143,16 +144,6 @@ export const buildCanvasTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -166,6 +157,9 @@ export const buildCanvasTextToImageGraph = ( ], }; + // Add VAE + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index cc118960e1..2e4383c3e7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -10,7 +10,10 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { + IMAGE_COLLECTION, + IMAGE_COLLECTION_ITERATE, IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, @@ -20,8 +23,6 @@ import { NOISE, POSITIVE_CONDITIONING, RESIZE, - IMAGE_COLLECTION, - IMAGE_COLLECTION_ITERATE, } from './constants'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -136,16 +137,6 @@ export const buildLinearImageToImageGraph = ( field: 'clip', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: LATENTS_TO_LATENTS, @@ -176,16 +167,7 @@ export const buildLinearImageToImageGraph = ( field: 'noise', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: IMAGE_TO_LATENTS, - field: 'vae', - }, - }, + { source: { node_id: MAIN_MODEL_LOADER, @@ -322,6 +304,8 @@ export const buildLinearImageToImageGraph = ( }, }); } + // Add VAE + addVAEToGraph(graph, state); // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index f3d98381c9..e0e71a00a2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, @@ -136,16 +137,6 @@ export const buildLinearTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: MAIN_MODEL_LOADER, - field: 'vae', - }, - destination: { - node_id: LATENTS_TO_IMAGE, - field: 'vae', - }, - }, { source: { node_id: NOISE, @@ -159,6 +150,9 @@ export const buildLinearTextToImageGraph = ( ], }; + // Add Custom VAE Support + addVAEToGraph(graph, state); + // add dynamic prompts, mutating `graph` addDynamicPromptsToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index afc03fdc60..58a7d0335b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -8,6 +8,7 @@ export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; export const MAIN_MODEL_LOADER = 'main_model_loader'; +export const VAE_LOADER = 'vae_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; From 089d95baebde4244628d54a7df230d31d24da1a6 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 21:26:41 +1200 Subject: [PATCH 26/51] fix: Change graph id vals to constants --- .../nodes/util/graphBuilders/addVAEToGraph.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts index a4a10e574e..4dd3d644ee 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -2,10 +2,13 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; import { + IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_LATENTS, INPAINT, + INPAINT_GRAPH, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, + TEXT_TO_IMAGE_GRAPH, VAE_LOADER, } from './constants'; @@ -24,10 +27,7 @@ export const addVAEToGraph = ( }; } - if ( - graph.id === 'text_to_image_graph' || - graph.id === 'image_to_image_graph' - ) { + if (graph.id === TEXT_TO_IMAGE_GRAPH || graph.id === IMAGE_TO_IMAGE_GRAPH) { graph.edges.push({ source: { node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, @@ -40,7 +40,7 @@ export const addVAEToGraph = ( }); } - if (graph.id === 'image_to_image_graph') { + if (graph.id === IMAGE_TO_IMAGE_GRAPH) { graph.edges.push({ source: { node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, @@ -53,7 +53,7 @@ export const addVAEToGraph = ( }); } - if (graph.id === 'inpaint_graph') { + if (graph.id === INPAINT_GRAPH) { graph.edges.push({ source: { node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER, From 0988725c1b8313b406b8855fdbe37cf843a3f511 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 1 Jul 2023 21:37:49 +1200 Subject: [PATCH 27/51] fix: Floating gallery button showing up in Model Manager --- .../features/ui/components/FloatingGalleryButton.tsx | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx index 3e2c2153e6..1bab4abe02 100644 --- a/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx +++ b/invokeai/frontend/web/src/features/ui/components/FloatingGalleryButton.tsx @@ -1,22 +1,26 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; -import { useTranslation } from 'react-i18next'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { setShouldShowGallery } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; -import { MdPhotoLibrary } from 'react-icons/md'; -import { activeTabNameSelector, uiSelector } from '../store/uiSelectors'; import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { MdPhotoLibrary } from 'react-icons/md'; +import { InvokeTabName } from '../store/tabMap'; +import { activeTabNameSelector, uiSelector } from '../store/uiSelectors'; const floatingGalleryButtonSelector = createSelector( [activeTabNameSelector, uiSelector], (activeTabName, ui) => { const { shouldPinGallery, shouldShowGallery } = ui; + const noGalleryTabs: InvokeTabName[] = ['modelmanager']; return { shouldPinGallery, - shouldShowGalleryButton: !shouldShowGallery, + shouldShowGalleryButton: noGalleryTabs.includes(activeTabName) + ? false + : !shouldShowGallery, }; }, { memoizeOptions: { resultEqualityCheck: isEqual } } From 96bf92ead4b02822cfdd45d01d6a01f55b6030c9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 3 Jul 2023 19:32:54 -0400 Subject: [PATCH 28/51] add the import model router --- invokeai/app/api/routers/models.py | 47 +- .../app/services/model_manager_service.py | 44 + .../backend/install/model_install_backend.py | 46 +- invokeai/backend/model_management/__init__.py | 2 +- .../backend/model_management/model_manager.py | 43 +- invokeai/frontend/web/dist/index.html | 2 +- invokeai/frontend/web/dist/locales/en.json | 17 +- invokeai/frontend/web/stats.html | 1521 +---------------- 8 files changed, 233 insertions(+), 1489 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 0b03c8e729..dcbdbec82d 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,17 +2,17 @@ from typing import Literal, Optional, Union -from fastapi import Query +from fastapi import Query, Body from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType +from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) - class VaeRepo(BaseModel): repo_id: str = Field(description="The repo ID to use for this VAE") path: Optional[str] = Field(description="The path to the VAE") @@ -51,9 +51,12 @@ class CreateModelResponse(BaseModel): info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") status: str = Field(description="The status of the API response") -class ImportModelRequest(BaseModel): - name: str = Field(description="A model path, repo_id or URL to import") - prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files') +class ImportModelResponse(BaseModel): + name: str = Field(description="The name of the imported model") +# base_model: str = Field(description="The base model") +# model_type: str = Field(description="The model type") + info: AddModelResult = Field(description="The model info") + status: str = Field(description="The status of the API response") class ConversionRequest(BaseModel): name: str = Field(description="The name of the new model") @@ -86,7 +89,6 @@ async def list_models( models = parse_obj_as(ModelsList, { "models": models_raw }) return models - @models_router.post( "/", operation_id="update_model", @@ -109,27 +111,38 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", - responses={200: {"status": "success"}}, + responses= { + 201: {"description" : "The model imported successfully"}, + 404: {"description" : "The model could not be found"}, + }, + status_code=201, + response_model=ImportModelResponse ) async def import_model( - model_request: ImportModelRequest -) -> None: - """ Add Model """ - items_to_import = set([model_request.name]) + name: str = Query(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), +) -> ImportModelResponse: + """ Add a model using its local path, repo_id, or remote URL """ + items_to_import = {name} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type) + prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - if len(installed_models) > 0: - logger.info(f'Successfully imported {model_request.name}') + if info := installed_models.get(name): + logger.info(f'Successfully imported {name}, got {info}') + return ImportModelResponse( + name = name, + info = info, + status = "success", + ) else: - logger.error(f'Model {model_request.name} not imported') - raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported') + logger.error(f'Model {name} not imported') + raise HTTPException(status_code=404, detail=f'Model {name} not found') @models_router.delete( "/{model_name}", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8b46b17ad0..98b4d81ba8 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -135,6 +135,29 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + pass + @abstractmethod def commit(self, conf_file: Path = None) -> None: """ @@ -361,3 +384,24 @@ class ModelManagerService(ModelManagerServiceBase): def logger(self): return self.mgr.logger + def heuristic_import(self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, + )->Dict[str, AddModelResult]: + '''Import a list of paths, repo_ids or URLs. Returns the set of + successfully imported items. + :param items_to_import: Set of strings corresponding to models to be imported. + :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. + + The prediction type helper is necessary to distinguish between + models based on Stable Diffusion 2 Base (requiring + SchedulerPredictionType.Epsilson) and Stable Diffusion 768 + (requiring SchedulerPredictionType.VPrediction). It is + generally impossible to do this programmatically, so the + prediction_type_helper usually asks the user to choose. + + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. + ''' + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 1c2f4d2fc1..a10fa852c0 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -18,7 +18,7 @@ from tqdm import tqdm import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType +from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.util import download_with_resume from ..util.logging import InvokeAILogger @@ -166,17 +166,22 @@ class ModelInstall(object): # add requested models for path in selections.install_models: logger.info(f'Installing {path} [{job}/{jobs}]') - self.heuristic_install(path) + self.heuristic_import(path) job += 1 self.mgr.commit() - def heuristic_install(self, + def heuristic_import(self, model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None)->Set[Path]: + models_installed: Set[Path]=None)->Dict[str, AddModelResult]: + ''' + :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL + :param models_installed: Set of installed models, used for recursive invocation + Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. + ''' if not models_installed: - models_installed = set() + models_installed = dict() # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url @@ -185,24 +190,24 @@ class ModelInstall(object): try: # checkpoint file, or similar if path.is_file(): - models_installed.add(self._install_path(path)) + models_installed.update(self._install_path(path)) # folders style or similar elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - models_installed.add(self._install_path(path)) + models_installed.update(self._install_path(path)) # recursive scan elif path.is_dir(): for child in path.iterdir(): - self.heuristic_install(child, models_installed=models_installed) + self.heuristic_import(child, models_installed=models_installed) # huggingface repo elif len(str(path).split('/')) == 2: - models_installed.add(self._install_repo(str(path))) + models_installed.update(self._install_repo(str(path))) # a URL elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): - models_installed.add(self._install_url(model_path_id_or_url)) + models_installed.update(self._install_url(model_path_id_or_url)) else: logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') @@ -214,24 +219,25 @@ class ModelInstall(object): # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: + def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: try: - # logger.debug(f'Probing {path}') + model_result = None info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) model_name = path.stem if info.format=='checkpoint' else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') attributes = self._make_attributes(path,info) - self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) + model_result = self.mgr.add_model(model_name = model_name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes, + ) except Exception as e: logger.warning(f'{str(e)} Skipping registration.') - return path + return {} + return {str(path): model_result} - def _install_url(self, url: str)->Path: + def _install_url(self, url: str)->dict: # copy to a staging area, probe, import and delete with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) @@ -244,7 +250,7 @@ class ModelInstall(object): # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->Path: + def _install_repo(self, repo_id: str)->dict: hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index fb3b20a20a..34e0b15728 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,7 +1,7 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, ModelInfo +from .model_manager import ModelManager, ModelInfo, AddModelResult from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 73c68e8afc..a8cbb50474 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -233,14 +233,14 @@ import hashlib import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Optional, List, Tuple, Union, Set, Callable, types +from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree import torch from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pydantic import BaseModel +from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig @@ -278,8 +278,13 @@ class InvalidModelError(Exception): "Raised when an invalid model is requested" pass -MAX_CACHE_SIZE = 6.0 # GB +class AddModelResult(BaseModel): + name: str = Field(description="The name of the model after import") + model_type: ModelType = Field(description="The type of model") + base_model: BaseModelType = Field(description="The base model") + config: ModelConfigBase = Field(description="The configuration of the model") +MAX_CACHE_SIZE = 6.0 # GB class ConfigMeta(BaseModel): version: str @@ -571,13 +576,16 @@ class ModelManager(object): model_type: ModelType, model_attributes: dict, clobber: bool = False, - ) -> None: + ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. On a successful update, the config will be changed in memory and the method will return True. Will fail with an assertion error if provided attributes are incorrect or the model name is missing. + + The returned dict has the same format as the dict returned by + model_info(). """ model_class = MODEL_CLASSES[base_model][model_type] @@ -601,12 +609,18 @@ class ModelManager(object): old_model_cache.unlink() # remove in-memory cache - # note: it not garantie to release memory(model can has other references) + # note: it not guaranteed to release memory(model can has other references) cache_ids = self.cache_keys.pop(model_key, []) for cache_id in cache_ids: self.cache.uncache_model(cache_id) self.models[model_key] = model_config + return AddModelResult( + name = model_name, + model_type = model_type, + base_model = base_model, + config = model_config, + ) def search_models(self, search_folder): self.logger.info(f"Finding Models In: {search_folder}") @@ -729,7 +743,7 @@ class ModelManager(object): if (new_models_found or imported_models) and self.config_path: self.commit() - def autoimport(self)->set[Path]: + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. ''' @@ -742,7 +756,6 @@ class ModelManager(object): prediction_type_helper = ask_user_for_prediction_type, ) - installed = set() scanned_dirs = set() config = self.app_config @@ -756,13 +769,14 @@ class ModelManager(object): continue self.logger.info(f'Scanning {autodir} for models to import') + installed = dict() autodir = self.app_config.root_path / autodir if not autodir.exists(): continue items_scanned = 0 - new_models_found = set() + new_models_found = dict() for root, dirs, files in os.walk(autodir): items_scanned += len(dirs) + len(files) @@ -772,7 +786,7 @@ class ModelManager(object): scanned_dirs.add(path) continue if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): - new_models_found.update(installer.heuristic_install(path)) + new_models_found.update(installer.heuristic_import(path)) scanned_dirs.add(path) for f in files: @@ -780,7 +794,7 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: continue if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - new_models_found.update(installer.heuristic_install(path)) + new_models_found.update(installer.heuristic_import(path)) self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') installed.update(new_models_found) @@ -790,7 +804,7 @@ class ModelManager(object): def heuristic_import(self, items_to_import: Set[str], prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Set[str]: + )->Dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -803,17 +817,20 @@ class ModelManager(object): generally impossible to do this programmatically, so the prediction_type_helper usually asks the user to choose. + The result is a set of successfully installed models. Each element + of the set is a dict corresponding to the newly-created OmegaConf stanza for + that model. ''' # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall - successfully_installed = set() + successfully_installed = dict() installer = ModelInstall(config = self.app_config, prediction_type_helper = prediction_type_helper, model_manager = self) for thing in items_to_import: try: - installed = installer.heuristic_install(thing) + installed = installer.heuristic_import(thing) successfully_installed.update(installed) except Exception as e: self.logger.warning(f'{thing} could not be imported: {str(e)}') diff --git a/invokeai/frontend/web/dist/index.html b/invokeai/frontend/web/dist/index.html index 6c4c1c21ae..a0adc1d803 100644 --- a/invokeai/frontend/web/dist/index.html +++ b/invokeai/frontend/web/dist/index.html @@ -12,7 +12,7 @@ margin: 0; } - + diff --git a/invokeai/frontend/web/dist/locales/en.json b/invokeai/frontend/web/dist/locales/en.json index 7a73bae411..6fb56a2979 100644 --- a/invokeai/frontend/web/dist/locales/en.json +++ b/invokeai/frontend/web/dist/locales/en.json @@ -24,16 +24,13 @@ }, "common": { "hotkeysLabel": "Hotkeys", - "themeLabel": "Theme", + "darkMode": "Dark Mode", + "lightMode": "Light Mode", "languagePickerLabel": "Language", "reportBugLabel": "Report Bug", "githubLabel": "Github", "discordLabel": "Discord", "settingsLabel": "Settings", - "darkTheme": "Dark", - "lightTheme": "Light", - "greenTheme": "Green", - "oceanTheme": "Ocean", "langArabic": "العربية", "langEnglish": "English", "langDutch": "Nederlands", @@ -55,6 +52,7 @@ "unifiedCanvas": "Unified Canvas", "linear": "Linear", "nodes": "Node Editor", + "modelmanager": "Model Manager", "postprocessing": "Post Processing", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "postProcessing": "Post Processing", @@ -336,6 +334,7 @@ "modelManager": { "modelManager": "Model Manager", "model": "Model", + "vae": "VAE", "allModels": "All Models", "checkpointModels": "Checkpoints", "diffusersModels": "Diffusers", @@ -351,6 +350,7 @@ "scanForModels": "Scan For Models", "addManually": "Add Manually", "manual": "Manual", + "baseModel": "Base Model", "name": "Name", "nameValidationMsg": "Enter a name for your model", "description": "Description", @@ -363,6 +363,7 @@ "repoIDValidationMsg": "Online repository of your model", "vaeLocation": "VAE Location", "vaeLocationValidationMsg": "Path to where your VAE is located.", + "variant": "Variant", "vaeRepoID": "VAE Repo ID", "vaeRepoIDValidationMsg": "Online repository of your VAE", "width": "Width", @@ -524,7 +525,8 @@ "initialImage": "Initial Image", "showOptionsPanel": "Show Options Panel", "hidePreview": "Hide Preview", - "showPreview": "Show Preview" + "showPreview": "Show Preview", + "controlNetControlMode": "Control Mode" }, "settings": { "models": "Models", @@ -547,7 +549,8 @@ "general": "General", "generation": "Generation", "ui": "User Interface", - "availableSchedulers": "Available Schedulers" + "favoriteSchedulers": "Favorite Schedulers", + "favoriteSchedulersPlaceholder": "No schedulers favorited" }, "toast": { "serverError": "Server Error", diff --git a/invokeai/frontend/web/stats.html b/invokeai/frontend/web/stats.html index dc999e13df..7c7df1671a 100644 --- a/invokeai/frontend/web/stats.html +++ b/invokeai/frontend/web/stats.html @@ -145,9 +145,9 @@ main { var drawChart = (function (exports) { 'use strict'; - var n,l$1,u$1,t$1,o$2,r$1,f$1={},e$1=[],c$1=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i;function s$1(n,l){for(var u in l)n[u]=l[u];return n}function a$1(n){var l=n.parentNode;l&&l.removeChild(n);}function h$1(l,u,i){var t,o,r,f={};for(r in u)"key"==r?t=u[r]:"ref"==r?o=u[r]:f[r]=u[r];if(arguments.length>2&&(f.children=arguments.length>3?n.call(arguments,2):i),"function"==typeof l&&null!=l.defaultProps)for(r in l.defaultProps)void 0===f[r]&&(f[r]=l.defaultProps[r]);return v$1(l,f,t,o,null)}function v$1(n,i,t,o,r){var f={type:n,props:i,key:t,ref:o,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,__h:null,constructor:void 0,__v:null==r?++u$1:r};return null==r&&null!=l$1.vnode&&l$1.vnode(f),f}function p$1(n){return n.children}function d$1(n,l){this.props=n,this.context=l;}function _$2(n,l){if(null==l)return n.__?_$2(n.__,n.__.__k.indexOf(n)+1):null;for(var u;l0?v$1(k.type,k.props,k.key,k.ref?k.ref:null,k.__v):k)){if(k.__=u,k.__b=u.__b+1,null===(d=x[h])||d&&k.key==d.key&&k.type===d.type)x[h]=void 0;else for(y=0;y2&&(f.children=arguments.length>3?n.call(arguments,2):i),"function"==typeof l&&null!=l.defaultProps)for(r in l.defaultProps)void 0===f[r]&&(f[r]=l.defaultProps[r]);return d$1(l,f,t,o,null)}function d$1(n,i,t,o,r){var f={type:n,props:i,key:t,ref:o,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,__h:null,constructor:void 0,__v:null==r?++u$1:r};return null==r&&null!=l$1.vnode&&l$1.vnode(f),f}function k$1(n){return n.children}function b$1(n,l){this.props=n,this.context=l;}function g$1(n,l){if(null==l)return n.__?g$1(n.__,n.__.__k.indexOf(n)+1):null;for(var u;ll&&t$1.sort(f$1));x.__r=0;}function P(n,l,u,i,t,o,r,f,e,a){var h,p,y,_,b,m,w,x=i&&i.__k||s$1,P=x.length;for(u.__k=[],h=0;h0?d$1(_.type,_.props,_.key,_.ref?_.ref:null,_.__v):_)){if(_.__=u,_.__b=u.__b+1,null===(y=x[h])||y&&_.key==y.key&&_.type===y.type)x[h]=void 0;else for(p=0;p=0;l--)if((u=n.__k[l])&&(i=A(u)))return i;return null}function H(n,l,u,i,t){var o;for(o in u)"children"===o||"key"===o||o in l||T$1(n,o,null,u[o],i);for(o in l)t&&"function"!=typeof l[o]||"children"===o||"key"===o||"value"===o||"checked"===o||u[o]===l[o]||T$1(n,o,l[o],u[o],i);}function I(n,l,u){"-"===l[0]?n.setProperty(l,null==u?"":u):n[l]=null==u?"":"number"!=typeof u||a$1.test(l)?u:u+"px";}function T$1(n,l,u,i,t){var o;n:if("style"===l)if("string"==typeof u)n.style.cssText=u;else {if("string"==typeof i&&(n.style.cssText=i=""),i)for(l in i)u&&l in u||I(n.style,l,"");if(u)for(l in u)i&&u[l]===i[l]||I(n.style,l,u[l]);}else if("o"===l[0]&&"n"===l[1])o=l!==(l=l.replace(/Capture$/,"")),l=l.toLowerCase()in n?l.toLowerCase().slice(2):l.slice(2),n.l||(n.l={}),n.l[l+o]=u,u?i||n.addEventListener(l,o?z$1:j$1,o):n.removeEventListener(l,o?z$1:j$1,o);else if("dangerouslySetInnerHTML"!==l){if(t)l=l.replace(/xlink(H|:h)/,"h").replace(/sName$/,"s");else if("width"!==l&&"height"!==l&&"href"!==l&&"list"!==l&&"form"!==l&&"tabIndex"!==l&&"download"!==l&&"rowSpan"!==l&&"colSpan"!==l&&l in n)try{n[l]=null==u?"":u;break n}catch(n){}"function"==typeof u||(null==u||!1===u&&"-"!==l[4]?n.removeAttribute(l):n.setAttribute(l,u));}}function j$1(n){return this.l[n.type+!1](l$1.event?l$1.event(n):n)}function z$1(n){return this.l[n.type+!0](l$1.event?l$1.event(n):n)}function L(n,u,i,t,o,r,f,e,c){var s,a,p,y,d,_,g,m,w,x,C,S,$,A,H,I=u.type;if(void 0!==u.constructor)return null;null!=i.__h&&(c=i.__h,e=u.__e=i.__e,u.__h=null,r=[e]),(s=l$1.__b)&&s(u);try{n:if("function"==typeof I){if(m=u.props,w=(s=I.contextType)&&t[s.__c],x=s?w?w.props.value:s.__:t,i.__c?g=(a=u.__c=i.__c).__=a.__E:("prototype"in I&&I.prototype.render?u.__c=a=new I(m,x):(u.__c=a=new b$1(m,x),a.constructor=I,a.render=B$1),w&&w.sub(a),a.props=m,a.state||(a.state={}),a.context=x,a.__n=t,p=a.__d=!0,a.__h=[],a._sb=[]),null==a.__s&&(a.__s=a.state),null!=I.getDerivedStateFromProps&&(a.__s==a.state&&(a.__s=h$1({},a.__s)),h$1(a.__s,I.getDerivedStateFromProps(m,a.__s))),y=a.props,d=a.state,a.__v=u,p)null==I.getDerivedStateFromProps&&null!=a.componentWillMount&&a.componentWillMount(),null!=a.componentDidMount&&a.__h.push(a.componentDidMount);else {if(null==I.getDerivedStateFromProps&&m!==y&&null!=a.componentWillReceiveProps&&a.componentWillReceiveProps(m,x),!a.__e&&null!=a.shouldComponentUpdate&&!1===a.shouldComponentUpdate(m,a.__s,x)||u.__v===i.__v){for(u.__v!==i.__v&&(a.props=m,a.state=a.__s,a.__d=!1),a.__e=!1,u.__e=i.__e,u.__k=i.__k,u.__k.forEach(function(n){n&&(n.__=u);}),C=0;C=i.__.length&&i.__.push({__V:c}),i.__[t]}function p(n){return o=1,y(B$1,n)}function y(n,u,i){var o=d(t++,2);if(o.t=n,!o.__c&&(o.__=[i?i(u):B$1(void 0,u),function(n){var t=o.__N?o.__N[0]:o.__[0],r=o.t(t,n);t!==r&&(o.__N=[r,o.__[1]],o.__c.setState({}));}],o.__c=r,!r.u)){r.u=!0;var f=r.shouldComponentUpdate;r.shouldComponentUpdate=function(n,t,r){if(!o.__c.__H)return !0;var u=o.__c.__H.__.filter(function(n){return n.__c});if(u.every(function(n){return !n.__N}))return !f||f.call(this,n,t,r);var i=!1;return u.forEach(function(n){if(n.__N){var t=n.__[0];n.__=n.__N,n.__N=void 0,t!==n.__[0]&&(i=!0);}}),!(!i&&o.__c.props===n)&&(!f||f.call(this,n,t,r))};}return o.__N||o.__}function h(u,i){var o=d(t++,3);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__H.__h.push(o));}function s(u,i){var o=d(t++,4);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__h.push(o));}function _(n){return o=5,F(function(){return {current:n}},[])}function F(n,r){var u=d(t++,7);return z(u.__H,r)?(u.__V=n(),u.i=r,u.__h=n,u.__V):u.__}function T(n,t){return o=8,F(function(){return n},t)}function q(n){var u=r.context[n.__c],i=d(t++,9);return i.c=n,u?(null==i.__&&(i.__=!0,u.sub(r)),u.props.value):n.__}function b(){for(var t;t=f.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(k),t.__H.__h.forEach(w),t.__H.__h=[];}catch(r){t.__H.__h=[],l$1.__e(r,t.__v);}}l$1.__b=function(n){r=null,e&&e(n);},l$1.__r=function(n){a&&a(n),t=0;var i=(r=n.__c).__H;i&&(u===r?(i.__h=[],r.__h=[],i.__.forEach(function(n){n.__N&&(n.__=n.__N),n.__V=c,n.__N=n.i=void 0;})):(i.__h.forEach(k),i.__h.forEach(w),i.__h=[])),u=r;},l$1.diffed=function(t){v&&v(t);var o=t.__c;o&&o.__H&&(o.__H.__h.length&&(1!==f.push(o)&&i===l$1.requestAnimationFrame||((i=l$1.requestAnimationFrame)||j)(b)),o.__H.__.forEach(function(n){n.i&&(n.__H=n.i),n.__V!==c&&(n.__=n.__V),n.i=void 0,n.__V=c;})),u=r=null;},l$1.__c=function(t,r){r.some(function(t){try{t.__h.forEach(k),t.__h=t.__h.filter(function(n){return !n.__||w(n)});}catch(u){r.some(function(n){n.__h&&(n.__h=[]);}),r=[],l$1.__e(u,t.__v);}}),l&&l(t,r);},l$1.unmount=function(t){m&&m(t);var r,u=t.__c;u&&u.__H&&(u.__H.__.forEach(function(n){try{k(n);}catch(n){r=n;}}),u.__H=void 0,r&&l$1.__e(r,u.__v));};var g="function"==typeof requestAnimationFrame;function j(n){var t,r=function(){clearTimeout(u),g&&cancelAnimationFrame(t),setTimeout(n);},u=setTimeout(r,100);g&&(t=requestAnimationFrame(r));}function k(n){var t=r,u=n.__c;"function"==typeof u&&(n.__c=void 0,u()),r=t;}function w(n){var t=r;n.__c=n.__(),r=t;}function z(n,t){return !n||n.length!==t.length||t.some(function(t,r){return t!==n[r]})}function B$1(n,t){return "function"==typeof t?t(n):t} + var t,r,u,i,o=0,f=[],c=[],e=l$1.__b,a=l$1.__r,v=l$1.diffed,l=l$1.__c,m=l$1.unmount;function d(t,u){l$1.__h&&l$1.__h(r,t,o||u),o=0;var i=r.__H||(r.__H={__:[],__h:[]});return t>=i.__.length&&i.__.push({__V:c}),i.__[t]}function h(n){return o=1,s(B,n)}function s(n,u,i){var o=d(t++,2);if(o.t=n,!o.__c&&(o.__=[i?i(u):B(void 0,u),function(n){var t=o.__N?o.__N[0]:o.__[0],r=o.t(t,n);t!==r&&(o.__N=[r,o.__[1]],o.__c.setState({}));}],o.__c=r,!r.u)){var f=function(n,t,r){if(!o.__c.__H)return !0;var u=o.__c.__H.__.filter(function(n){return n.__c});if(u.every(function(n){return !n.__N}))return !c||c.call(this,n,t,r);var i=!1;return u.forEach(function(n){if(n.__N){var t=n.__[0];n.__=n.__N,n.__N=void 0,t!==n.__[0]&&(i=!0);}}),!(!i&&o.__c.props===n)&&(!c||c.call(this,n,t,r))};r.u=!0;var c=r.shouldComponentUpdate,e=r.componentWillUpdate;r.componentWillUpdate=function(n,t,r){if(this.__e){var u=c;c=void 0,f(n,t,r),c=u;}e&&e.call(this,n,t,r);},r.shouldComponentUpdate=f;}return o.__N||o.__}function p(u,i){var o=d(t++,3);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__H.__h.push(o));}function y(u,i){var o=d(t++,4);!l$1.__s&&z(o.__H,i)&&(o.__=u,o.i=i,r.__h.push(o));}function _(n){return o=5,F(function(){return {current:n}},[])}function F(n,r){var u=d(t++,7);return z(u.__H,r)?(u.__V=n(),u.i=r,u.__h=n,u.__V):u.__}function T(n,t){return o=8,F(function(){return n},t)}function q(n){var u=r.context[n.__c],i=d(t++,9);return i.c=n,u?(null==i.__&&(i.__=!0,u.sub(r)),u.props.value):n.__}function b(){for(var t;t=f.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(k),t.__H.__h.forEach(w),t.__H.__h=[];}catch(r){t.__H.__h=[],l$1.__e(r,t.__v);}}l$1.__b=function(n){r=null,e&&e(n);},l$1.__r=function(n){a&&a(n),t=0;var i=(r=n.__c).__H;i&&(u===r?(i.__h=[],r.__h=[],i.__.forEach(function(n){n.__N&&(n.__=n.__N),n.__V=c,n.__N=n.i=void 0;})):(i.__h.forEach(k),i.__h.forEach(w),i.__h=[],t=0)),u=r;},l$1.diffed=function(t){v&&v(t);var o=t.__c;o&&o.__H&&(o.__H.__h.length&&(1!==f.push(o)&&i===l$1.requestAnimationFrame||((i=l$1.requestAnimationFrame)||j)(b)),o.__H.__.forEach(function(n){n.i&&(n.__H=n.i),n.__V!==c&&(n.__=n.__V),n.i=void 0,n.__V=c;})),u=r=null;},l$1.__c=function(t,r){r.some(function(t){try{t.__h.forEach(k),t.__h=t.__h.filter(function(n){return !n.__||w(n)});}catch(u){r.some(function(n){n.__h&&(n.__h=[]);}),r=[],l$1.__e(u,t.__v);}}),l&&l(t,r);},l$1.unmount=function(t){m&&m(t);var r,u=t.__c;u&&u.__H&&(u.__H.__.forEach(function(n){try{k(n);}catch(n){r=n;}}),u.__H=void 0,r&&l$1.__e(r,u.__v));};var g="function"==typeof requestAnimationFrame;function j(n){var t,r=function(){clearTimeout(u),g&&cancelAnimationFrame(t),setTimeout(n);},u=setTimeout(r,100);g&&(t=requestAnimationFrame(r));}function k(n){var t=r,u=n.__c;"function"==typeof u&&(n.__c=void 0,u()),r=t;}function w(n){var t=r;n.__c=n.__(),r=t;}function z(n,t){return !n||n.length!==t.length||t.some(function(t,r){return t!==n[r]})}function B(n,t){return "function"==typeof t?t(n):t} const PLACEHOLDER = "bundle-*:**/file/**,**/file**, bundle-*:"; const SideBar = ({ availableSizeProperties, sizeProperty, setSizeProperty, onExcludeChange, onIncludeChange, }) => { - const [includeValue, setIncludeValue] = p(""); - const [excludeValue, setExcludeValue] = p(""); + const [includeValue, setIncludeValue] = h(""); + const [excludeValue, setExcludeValue] = h(""); const handleSizePropertyChange = (sizeProp) => () => { if (sizeProp !== sizeProperty) { setSizeProperty(sizeProp); @@ -682,23 +680,17 @@ var drawChart = (function (exports) { setExcludeValue(value); onExcludeChange(value); }; - return (o$1("aside", Object.assign({ className: "sidebar" }, { children: [o$1("div", Object.assign({ className: "size-selectors" }, { children: availableSizeProperties.length > 1 && + return (o$1("aside", { className: "sidebar", children: [o$1("div", { className: "size-selectors", children: availableSizeProperties.length > 1 && availableSizeProperties.map((sizeProp) => { const id = `selector-${sizeProp}`; - return (o$1("div", Object.assign({ className: "size-selector" }, { children: [o$1("input", { type: "radio", id: id, checked: sizeProp === sizeProperty, onChange: handleSizePropertyChange(sizeProp) }), o$1("label", Object.assign({ htmlFor: id }, { children: LABELS[sizeProp] }))] }), sizeProp)); - }) })), o$1("div", Object.assign({ className: "module-filters" }, { children: [o$1("div", Object.assign({ className: "module-filter" }, { children: [o$1("label", Object.assign({ htmlFor: "module-filter-exclude" }, { children: "Exclude" })), o$1("input", { type: "text", id: "module-filter-exclude", value: excludeValue, onInput: handleExcludeChange, placeholder: PLACEHOLDER })] })), o$1("div", Object.assign({ className: "module-filter" }, { children: [o$1("label", Object.assign({ htmlFor: "module-filter-include" }, { children: "Include" })), o$1("input", { type: "text", id: "module-filter-include", value: includeValue, onInput: handleIncludeChange, placeholder: PLACEHOLDER })] }))] }))] }))); + return (o$1("div", { className: "size-selector", children: [o$1("input", { type: "radio", id: id, checked: sizeProp === sizeProperty, onChange: handleSizePropertyChange(sizeProp) }), o$1("label", { htmlFor: id, children: LABELS[sizeProp] })] }, sizeProp)); + }) }), o$1("div", { className: "module-filters", children: [o$1("div", { className: "module-filter", children: [o$1("label", { htmlFor: "module-filter-exclude", children: "Exclude" }), o$1("input", { type: "text", id: "module-filter-exclude", value: excludeValue, onInput: handleExcludeChange, placeholder: PLACEHOLDER })] }), o$1("div", { className: "module-filter", children: [o$1("label", { htmlFor: "module-filter-include", children: "Include" }), o$1("input", { type: "text", id: "module-filter-include", value: includeValue, onInput: handleIncludeChange, placeholder: PLACEHOLDER })] })] })] })); }; function getDefaultExportFromCjs (x) { return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x; } - var picomatchBrowserExports = {}; - var picomatchBrowser = { - get exports(){ return picomatchBrowserExports; }, - set exports(v){ picomatchBrowserExports = v; }, - }; - var utils$3 = {}; const WIN_SLASH = '\\\\/'; @@ -941,7 +933,7 @@ var drawChart = (function (exports) { } else { return path.replace(/\/$/, '').replace(/.*\//, ''); } - }; + }; } (utils$3)); const utils$2 = utils$3; @@ -2738,12 +2730,9 @@ var drawChart = (function (exports) { var picomatch_1 = picomatch; - (function (module) { + var picomatchBrowser = picomatch_1; - module.exports = picomatch_1; - } (picomatchBrowser)); - - var pm = /*@__PURE__*/getDefaultExportFromCjs(picomatchBrowserExports); + var pm = /*@__PURE__*/getDefaultExportFromCjs(picomatchBrowser); function isArray(arg) { return Array.isArray(arg); @@ -2834,8 +2823,8 @@ var drawChart = (function (exports) { })); }; const useFilter = () => { - const [includeFilter, setIncludeFilter] = p(""); - const [excludeFilter, setExcludeFilter] = p(""); + const [includeFilter, setIncludeFilter] = h(""); + const [excludeFilter, setExcludeFilter] = h(""); const setIncludeFilterTrottled = F(() => throttleFilter(setIncludeFilter, 200), []); const setExcludeFilterTrottled = F(() => throttleFilter(setExcludeFilter, 200), []); const isIncluded = F(() => createFilter(prepareFilter(includeFilter), prepareFilter(excludeFilter)), [includeFilter, excludeFilter]); @@ -2924,6 +2913,7 @@ var drawChart = (function (exports) { const ascendingBisect = bisector(ascending); const bisectRight = ascendingBisect.right; bisector(number$1).center; + var bisect = bisectRight; class InternMap extends Map { constructor(entries, key = keyof) { @@ -2997,59 +2987,60 @@ var drawChart = (function (exports) { })(values, 0); } - var e10 = Math.sqrt(50), + const e10 = Math.sqrt(50), e5 = Math.sqrt(10), e2 = Math.sqrt(2); - function ticks(start, stop, count) { - var reverse, - i = -1, - n, - ticks, - step; - - stop = +stop, start = +start, count = +count; - if (start === stop && count > 0) return [start]; - if (reverse = stop < start) n = start, start = stop, stop = n; - if ((step = tickIncrement(start, stop, count)) === 0 || !isFinite(step)) return []; - - if (step > 0) { - let r0 = Math.round(start / step), r1 = Math.round(stop / step); - if (r0 * step < start) ++r0; - if (r1 * step > stop) --r1; - ticks = new Array(n = r1 - r0 + 1); - while (++i < n) ticks[i] = (r0 + i) * step; + function tickSpec(start, stop, count) { + const step = (stop - start) / Math.max(0, count), + power = Math.floor(Math.log10(step)), + error = step / Math.pow(10, power), + factor = error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1; + let i1, i2, inc; + if (power < 0) { + inc = Math.pow(10, -power) / factor; + i1 = Math.round(start * inc); + i2 = Math.round(stop * inc); + if (i1 / inc < start) ++i1; + if (i2 / inc > stop) --i2; + inc = -inc; } else { - step = -step; - let r0 = Math.round(start * step), r1 = Math.round(stop * step); - if (r0 / step < start) ++r0; - if (r1 / step > stop) --r1; - ticks = new Array(n = r1 - r0 + 1); - while (++i < n) ticks[i] = (r0 + i) / step; + inc = Math.pow(10, power) * factor; + i1 = Math.round(start / inc); + i2 = Math.round(stop / inc); + if (i1 * inc < start) ++i1; + if (i2 * inc > stop) --i2; } + if (i2 < i1 && 0.5 <= count && count < 2) return tickSpec(start, stop, count * 2); + return [i1, i2, inc]; + } - if (reverse) ticks.reverse(); - + function ticks(start, stop, count) { + stop = +stop, start = +start, count = +count; + if (!(count > 0)) return []; + if (start === stop) return [start]; + const reverse = stop < start, [i1, i2, inc] = reverse ? tickSpec(stop, start, count) : tickSpec(start, stop, count); + if (!(i2 >= i1)) return []; + const n = i2 - i1 + 1, ticks = new Array(n); + if (reverse) { + if (inc < 0) for (let i = 0; i < n; ++i) ticks[i] = (i2 - i) / -inc; + else for (let i = 0; i < n; ++i) ticks[i] = (i2 - i) * inc; + } else { + if (inc < 0) for (let i = 0; i < n; ++i) ticks[i] = (i1 + i) / -inc; + else for (let i = 0; i < n; ++i) ticks[i] = (i1 + i) * inc; + } return ticks; } function tickIncrement(start, stop, count) { - var step = (stop - start) / Math.max(0, count), - power = Math.floor(Math.log(step) / Math.LN10), - error = step / Math.pow(10, power); - return power >= 0 - ? (error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1) * Math.pow(10, power) - : -Math.pow(10, -power) / (error >= e10 ? 10 : error >= e5 ? 5 : error >= e2 ? 2 : 1); + stop = +stop, start = +start, count = +count; + return tickSpec(start, stop, count)[2]; } function tickStep(start, stop, count) { - var step0 = Math.abs(stop - start) / Math.max(0, count), - step1 = Math.pow(10, Math.floor(Math.log(step0) / Math.LN10)), - error = step0 / step1; - if (error >= e10) step1 *= 10; - else if (error >= e5) step1 *= 5; - else if (error >= e2) step1 *= 2; - return stop < start ? -step1 : step1; + stop = +stop, start = +start, count = +count; + const reverse = stop < start, inc = reverse ? tickIncrement(stop, start, count) : tickIncrement(start, stop, count); + return (reverse ? -1 : 1) * (inc < 0 ? 1 / -inc : inc); } const TOP_PADDING = 20; @@ -3075,7 +3066,7 @@ var drawChart = (function (exports) { else { textProps.y = height / 2; } - s(() => { + y(() => { if (width == 0 || height == 0 || !textRef.current) { return; } @@ -3100,18 +3091,18 @@ var drawChart = (function (exports) { if (width == 0 || height == 0) { return null; } - return (o$1("g", Object.assign({ className: "node", transform: `translate(${x0},${y0})`, onClick: (event) => { + return (o$1("g", { className: "node", transform: `translate(${x0},${y0})`, onClick: (event) => { event.stopPropagation(); onClick(node); }, onMouseOver: (event) => { event.stopPropagation(); onMouseOver(node); - } }, { children: [o$1("rect", { fill: backgroundColor, rx: 2, ry: 2, width: x1 - x0, height: y1 - y0, stroke: selected ? "#fff" : undefined, "stroke-width": selected ? 2 : undefined }), o$1("text", Object.assign({ ref: textRef, fill: fontColor, onClick: (event) => { + }, children: [o$1("rect", { fill: backgroundColor, rx: 2, ry: 2, width: x1 - x0, height: y1 - y0, stroke: selected ? "#fff" : undefined, "stroke-width": selected ? 2 : undefined }), o$1("text", Object.assign({ ref: textRef, fill: fontColor, onClick: (event) => { var _a; if (((_a = window.getSelection()) === null || _a === void 0 ? void 0 : _a.toString()) !== "") { event.stopPropagation(); } - } }, textProps, { children: data.name }))] }))); + } }, textProps, { children: data.name }))] })); }; const TreeMap = ({ root, onNodeHover, selectedNode, onNodeClick, }) => { @@ -3128,18 +3119,14 @@ var drawChart = (function (exports) { return nestedData; }, [root]); console.timeEnd("layering"); - return (o$1("svg", Object.assign({ xmlns: "http://www.w3.org/2000/svg", viewBox: `0 0 ${width} ${height}` }, { children: nestedData.map(({ key, values }) => { - return (o$1("g", Object.assign({ className: "layer" }, { children: values.map((node) => { + return (o$1("svg", { xmlns: "http://www.w3.org/2000/svg", viewBox: `0 0 ${width} ${height}`, children: nestedData.map(({ key, values }) => { + return (o$1("g", { className: "layer", children: values.map((node) => { return (o$1(Node, { node: node, onMouseOver: onNodeHover, selected: selectedNode === node, onClick: onNodeClick }, getModuleIds(node.data).nodeUid.id)); - }) }), key)); - }) }))); + }) }, key)); + }) })); }; - var bytesExports = {}; - var bytes$1 = { - get exports(){ return bytesExports; }, - set exports(v){ bytesExports = v; }, - }; + var bytes$1 = {exports: {}}; /*! * bytes @@ -3154,8 +3141,8 @@ var drawChart = (function (exports) { */ bytes$1.exports = bytes; - var format_1 = bytesExports.format = format$1; - bytesExports.parse = parse; + var format_1 = bytes$1.exports.format = format$1; + bytes$1.exports.parse = parse; /** * Module variables. @@ -3318,7 +3305,7 @@ var drawChart = (function (exports) { const Tooltip = ({ node, visible, root, sizeProperty, }) => { const { availableSizeProperties, getModuleSize, data } = q(StaticContext); const ref = _(null); - const [style, setStyle] = p({}); + const [style, setStyle] = h({}); const content = F(() => { if (!node) return null; @@ -3336,7 +3323,7 @@ var drawChart = (function (exports) { const mainUid = data.nodeParts[node.data.uid].metaUid; dataNode = data.nodeMetas[mainUid]; } - return (o$1(p$1, { children: [o$1("div", { children: path }), availableSizeProperties.map((sizeProp) => { + return (o$1(k$1, { children: [o$1("div", { children: path }), availableSizeProperties.map((sizeProp) => { if (sizeProp === sizeProperty) { return (o$1("div", { children: [o$1("b", { children: [LABELS[sizeProp], ": ", format_1(mainSize)] }), " ", "(", percentageString, ")"] }, sizeProp)); } @@ -3346,7 +3333,7 @@ var drawChart = (function (exports) { }), o$1("br", {}), dataNode && dataNode.importedBy.length > 0 && (o$1("div", { children: [o$1("div", { children: [o$1("b", { children: "Imported By" }), ":"] }), dataNode.importedBy.map(({ uid }) => { const id = data.nodeMetas[uid].id; return o$1("div", { children: id }, id); - })] })), o$1("br", {}), o$1("small", { children: data.options.sourcemap ? SOURCEMAP_RENDERED : RENDRED }), (data.options.gzip || data.options.brotli) && (o$1(p$1, { children: [o$1("br", {}), o$1("small", { children: COMPRESSED })] }))] })); + })] })), o$1("br", {}), o$1("small", { children: data.options.sourcemap ? SOURCEMAP_RENDERED : RENDRED }), (data.options.gzip || data.options.brotli) && (o$1(k$1, { children: [o$1("br", {}), o$1("small", { children: COMPRESSED })] }))] })); }, [availableSizeProperties, data, getModuleSize, node, root.data, sizeProperty]); const updatePosition = (mouseCoords) => { if (!ref.current) @@ -3366,7 +3353,7 @@ var drawChart = (function (exports) { } setStyle(pos); }; - h(() => { + p(() => { const handleMouseMove = (event) => { updatePosition({ x: event.pageX, @@ -3378,13 +3365,13 @@ var drawChart = (function (exports) { document.removeEventListener("mousemove", handleMouseMove, true); }; }, []); - return (o$1("div", Object.assign({ className: `tooltip ${visible ? "" : "tooltip-hidden"}`, ref: ref, style: style }, { children: content }))); + return (o$1("div", { className: `tooltip ${visible ? "" : "tooltip-hidden"}`, ref: ref, style: style, children: content })); }; const Chart = ({ root, sizeProperty, selectedNode, setSelectedNode, }) => { - const [showTooltip, setShowTooltip] = p(false); - const [tooltipNode, setTooltipNode] = p(undefined); - h(() => { + const [showTooltip, setShowTooltip] = h(false); + const [tooltipNode, setTooltipNode] = h(undefined); + p(() => { const handleMouseOut = () => { setShowTooltip(false); }; @@ -3393,7 +3380,7 @@ var drawChart = (function (exports) { document.removeEventListener("mouseover", handleMouseOut); }; }, []); - return (o$1(p$1, { children: [o$1(TreeMap, { root: root, onNodeHover: (node) => { + return (o$1(k$1, { children: [o$1(TreeMap, { root: root, onNodeHover: (node) => { setTooltipNode(node); setShowTooltip(true); }, selectedNode: selectedNode, onNodeClick: (node) => { @@ -3403,8 +3390,8 @@ var drawChart = (function (exports) { const Main = () => { const { availableSizeProperties, rawHierarchy, getModuleSize, layout, data } = q(StaticContext); - const [sizeProperty, setSizeProperty] = p(availableSizeProperties[0]); - const [selectedNode, setSelectedNode] = p(undefined); + const [sizeProperty, setSizeProperty] = h(availableSizeProperties[0]); + const [selectedNode, setSelectedNode] = h(undefined); const { getModuleFilterMultiplier, setExcludeFilter, setIncludeFilter } = useFilter(); console.time("getNodeSizeMultiplier"); const getNodeSizeMultiplier = F(() => { @@ -3459,7 +3446,7 @@ var drawChart = (function (exports) { sizeProperty, ]); console.timeEnd("root hierarchy compute"); - return (o$1(p$1, { children: [o$1(SideBar, { sizeProperty: sizeProperty, availableSizeProperties: availableSizeProperties, setSizeProperty: setSizeProperty, onExcludeChange: setExcludeFilter, onIncludeChange: setIncludeFilter }), o$1(Chart, { root: root, sizeProperty: sizeProperty, selectedNode: selectedNode, setSelectedNode: setSelectedNode })] })); + return (o$1(k$1, { children: [o$1(SideBar, { sizeProperty: sizeProperty, availableSizeProperties: availableSizeProperties, setSizeProperty: setSizeProperty, onExcludeChange: setExcludeFilter, onIncludeChange: setIncludeFilter }), o$1(Chart, { root: root, sizeProperty: sizeProperty, selectedNode: selectedNode, setSelectedNode: setSelectedNode })] })); }; function initRange(domain, range) { @@ -3895,179 +3882,6 @@ var drawChart = (function (exports) { : m1) * 255; } - const radians = Math.PI / 180; - const degrees = 180 / Math.PI; - - // https://observablehq.com/@mbostock/lab-and-rgb - const K = 18, - Xn = 0.96422, - Yn = 1, - Zn = 0.82521, - t0$1 = 4 / 29, - t1$1 = 6 / 29, - t2 = 3 * t1$1 * t1$1, - t3 = t1$1 * t1$1 * t1$1; - - function labConvert(o) { - if (o instanceof Lab) return new Lab(o.l, o.a, o.b, o.opacity); - if (o instanceof Hcl) return hcl2lab(o); - if (!(o instanceof Rgb)) o = rgbConvert(o); - var r = rgb2lrgb(o.r), - g = rgb2lrgb(o.g), - b = rgb2lrgb(o.b), - y = xyz2lab((0.2225045 * r + 0.7168786 * g + 0.0606169 * b) / Yn), x, z; - if (r === g && g === b) x = z = y; else { - x = xyz2lab((0.4360747 * r + 0.3850649 * g + 0.1430804 * b) / Xn); - z = xyz2lab((0.0139322 * r + 0.0971045 * g + 0.7141733 * b) / Zn); - } - return new Lab(116 * y - 16, 500 * (x - y), 200 * (y - z), o.opacity); - } - - function lab(l, a, b, opacity) { - return arguments.length === 1 ? labConvert(l) : new Lab(l, a, b, opacity == null ? 1 : opacity); - } - - function Lab(l, a, b, opacity) { - this.l = +l; - this.a = +a; - this.b = +b; - this.opacity = +opacity; - } - - define(Lab, lab, extend(Color, { - brighter(k) { - return new Lab(this.l + K * (k == null ? 1 : k), this.a, this.b, this.opacity); - }, - darker(k) { - return new Lab(this.l - K * (k == null ? 1 : k), this.a, this.b, this.opacity); - }, - rgb() { - var y = (this.l + 16) / 116, - x = isNaN(this.a) ? y : y + this.a / 500, - z = isNaN(this.b) ? y : y - this.b / 200; - x = Xn * lab2xyz(x); - y = Yn * lab2xyz(y); - z = Zn * lab2xyz(z); - return new Rgb( - lrgb2rgb( 3.1338561 * x - 1.6168667 * y - 0.4906146 * z), - lrgb2rgb(-0.9787684 * x + 1.9161415 * y + 0.0334540 * z), - lrgb2rgb( 0.0719453 * x - 0.2289914 * y + 1.4052427 * z), - this.opacity - ); - } - })); - - function xyz2lab(t) { - return t > t3 ? Math.pow(t, 1 / 3) : t / t2 + t0$1; - } - - function lab2xyz(t) { - return t > t1$1 ? t * t * t : t2 * (t - t0$1); - } - - function lrgb2rgb(x) { - return 255 * (x <= 0.0031308 ? 12.92 * x : 1.055 * Math.pow(x, 1 / 2.4) - 0.055); - } - - function rgb2lrgb(x) { - return (x /= 255) <= 0.04045 ? x / 12.92 : Math.pow((x + 0.055) / 1.055, 2.4); - } - - function hclConvert(o) { - if (o instanceof Hcl) return new Hcl(o.h, o.c, o.l, o.opacity); - if (!(o instanceof Lab)) o = labConvert(o); - if (o.a === 0 && o.b === 0) return new Hcl(NaN, 0 < o.l && o.l < 100 ? 0 : NaN, o.l, o.opacity); - var h = Math.atan2(o.b, o.a) * degrees; - return new Hcl(h < 0 ? h + 360 : h, Math.sqrt(o.a * o.a + o.b * o.b), o.l, o.opacity); - } - - function hcl(h, c, l, opacity) { - return arguments.length === 1 ? hclConvert(h) : new Hcl(h, c, l, opacity == null ? 1 : opacity); - } - - function Hcl(h, c, l, opacity) { - this.h = +h; - this.c = +c; - this.l = +l; - this.opacity = +opacity; - } - - function hcl2lab(o) { - if (isNaN(o.h)) return new Lab(o.l, 0, 0, o.opacity); - var h = o.h * radians; - return new Lab(o.l, Math.cos(h) * o.c, Math.sin(h) * o.c, o.opacity); - } - - define(Hcl, hcl, extend(Color, { - brighter(k) { - return new Hcl(this.h, this.c, this.l + K * (k == null ? 1 : k), this.opacity); - }, - darker(k) { - return new Hcl(this.h, this.c, this.l - K * (k == null ? 1 : k), this.opacity); - }, - rgb() { - return hcl2lab(this).rgb(); - } - })); - - var A = -0.14861, - B = +1.78277, - C = -0.29227, - D = -0.90649, - E = +1.97294, - ED = E * D, - EB = E * B, - BC_DA = B * C - D * A; - - function cubehelixConvert(o) { - if (o instanceof Cubehelix) return new Cubehelix(o.h, o.s, o.l, o.opacity); - if (!(o instanceof Rgb)) o = rgbConvert(o); - var r = o.r / 255, - g = o.g / 255, - b = o.b / 255, - l = (BC_DA * b + ED * r - EB * g) / (BC_DA + ED - EB), - bl = b - l, - k = (E * (g - l) - C * bl) / D, - s = Math.sqrt(k * k + bl * bl) / (E * l * (1 - l)), // NaN if l=0 or l=1 - h = s ? Math.atan2(k, bl) * degrees - 120 : NaN; - return new Cubehelix(h < 0 ? h + 360 : h, s, l, o.opacity); - } - - function cubehelix$1(h, s, l, opacity) { - return arguments.length === 1 ? cubehelixConvert(h) : new Cubehelix(h, s, l, opacity == null ? 1 : opacity); - } - - function Cubehelix(h, s, l, opacity) { - this.h = +h; - this.s = +s; - this.l = +l; - this.opacity = +opacity; - } - - define(Cubehelix, cubehelix$1, extend(Color, { - brighter(k) { - k = k == null ? brighter : Math.pow(brighter, k); - return new Cubehelix(this.h, this.s, this.l * k, this.opacity); - }, - darker(k) { - k = k == null ? darker : Math.pow(darker, k); - return new Cubehelix(this.h, this.s, this.l * k, this.opacity); - }, - rgb() { - var h = isNaN(this.h) ? 0 : (this.h + 120) * radians, - l = +this.l, - a = isNaN(this.s) ? 0 : this.s * l * (1 - l), - cosh = Math.cos(h), - sinh = Math.sin(h); - return new Rgb( - 255 * (l + a * (A * cosh + B * sinh)), - 255 * (l + a * (C * cosh + D * sinh)), - 255 * (l + a * (E * cosh)), - this.opacity - ); - } - })); - var constant = x => () => x; function linear$1(a, d) { @@ -4082,11 +3896,6 @@ var drawChart = (function (exports) { }; } - function hue(a, b) { - var d = b - a; - return d ? linear$1(a, d > 180 || d < -180 ? d - 360 * Math.round(d / 360) : d) : constant(isNaN(a) ? b : a); - } - function gamma(y) { return (y = +y) === 1 ? nogamma : function(a, b) { return b - a ? exponential(a, b, y) : constant(isNaN(a) ? b : a); @@ -4268,105 +4077,6 @@ var drawChart = (function (exports) { }; } - var epsilon2 = 1e-12; - - function cosh(x) { - return ((x = Math.exp(x)) + 1 / x) / 2; - } - - function sinh(x) { - return ((x = Math.exp(x)) - 1 / x) / 2; - } - - function tanh(x) { - return ((x = Math.exp(2 * x)) - 1) / (x + 1); - } - - ((function zoomRho(rho, rho2, rho4) { - - // p0 = [ux0, uy0, w0] - // p1 = [ux1, uy1, w1] - function zoom(p0, p1) { - var ux0 = p0[0], uy0 = p0[1], w0 = p0[2], - ux1 = p1[0], uy1 = p1[1], w1 = p1[2], - dx = ux1 - ux0, - dy = uy1 - uy0, - d2 = dx * dx + dy * dy, - i, - S; - - // Special case for u0 ≅ u1. - if (d2 < epsilon2) { - S = Math.log(w1 / w0) / rho; - i = function(t) { - return [ - ux0 + t * dx, - uy0 + t * dy, - w0 * Math.exp(rho * t * S) - ]; - }; - } - - // General case. - else { - var d1 = Math.sqrt(d2), - b0 = (w1 * w1 - w0 * w0 + rho4 * d2) / (2 * w0 * rho2 * d1), - b1 = (w1 * w1 - w0 * w0 - rho4 * d2) / (2 * w1 * rho2 * d1), - r0 = Math.log(Math.sqrt(b0 * b0 + 1) - b0), - r1 = Math.log(Math.sqrt(b1 * b1 + 1) - b1); - S = (r1 - r0) / rho; - i = function(t) { - var s = t * S, - coshr0 = cosh(r0), - u = w0 / (rho2 * d1) * (coshr0 * tanh(rho * s + r0) - sinh(r0)); - return [ - ux0 + u * dx, - uy0 + u * dy, - w0 * coshr0 / cosh(rho * s + r0) - ]; - }; - } - - i.duration = S * 1000 * rho / Math.SQRT2; - - return i; - } - - zoom.rho = function(_) { - var _1 = Math.max(1e-3, +_), _2 = _1 * _1, _4 = _2 * _2; - return zoomRho(_1, _2, _4); - }; - - return zoom; - }))(Math.SQRT2, 2, 4); - - function cubehelix(hue) { - return (function cubehelixGamma(y) { - y = +y; - - function cubehelix(start, end) { - var h = hue((start = cubehelix$1(start)).h, (end = cubehelix$1(end)).h), - s = nogamma(start.s, end.s), - l = nogamma(start.l, end.l), - opacity = nogamma(start.opacity, end.opacity); - return function(t) { - start.h = h(t); - start.s = s(t); - start.l = l(Math.pow(t, y)); - start.opacity = opacity(t); - return start + ""; - }; - } - - cubehelix.gamma = cubehelixGamma; - - return cubehelix; - })(1); - } - - cubehelix(hue); - cubehelix(nogamma); - function constants(x) { return function() { return x; @@ -4422,7 +4132,7 @@ var drawChart = (function (exports) { } return function(x) { - var i = bisectRight(domain, x, 1, j) - 1; + var i = bisect(domain, x, 1, j) - 1; return r[i](d[i](x)); }; } @@ -4658,7 +4368,7 @@ var drawChart = (function (exports) { var map = Array.prototype.map, prefixes = ["y","z","a","f","p","n","µ","m","","k","M","G","T","P","E","Z","Y"]; - function formatLocale$1(locale) { + function formatLocale(locale) { var group = locale.grouping === undefined || locale.thousands === undefined ? identity : formatGroup(map.call(locale.grouping, Number), locale.thousands + ""), currencyPrefix = locale.currency === undefined ? "" : locale.currency[0] + "", currencySuffix = locale.currency === undefined ? "" : locale.currency[1] + "", @@ -4795,21 +4505,21 @@ var drawChart = (function (exports) { }; } - var locale$1; + var locale; var format; var formatPrefix; - defaultLocale$1({ + defaultLocale({ thousands: ",", grouping: [3], currency: ["$", ""] }); - function defaultLocale$1(definition) { - locale$1 = formatLocale$1(definition); - format = locale$1.format; - formatPrefix = locale$1.formatPrefix; - return locale$1; + function defaultLocale(definition) { + locale = formatLocale(definition); + format = locale.format; + formatPrefix = locale.formatPrefix; + return locale; } function precisionFixed(step) { @@ -4918,1055 +4628,6 @@ var drawChart = (function (exports) { return linearish(scale); } - const t0 = new Date, t1 = new Date; - - function timeInterval(floori, offseti, count, field) { - - function interval(date) { - return floori(date = arguments.length === 0 ? new Date : new Date(+date)), date; - } - - interval.floor = (date) => { - return floori(date = new Date(+date)), date; - }; - - interval.ceil = (date) => { - return floori(date = new Date(date - 1)), offseti(date, 1), floori(date), date; - }; - - interval.round = (date) => { - const d0 = interval(date), d1 = interval.ceil(date); - return date - d0 < d1 - date ? d0 : d1; - }; - - interval.offset = (date, step) => { - return offseti(date = new Date(+date), step == null ? 1 : Math.floor(step)), date; - }; - - interval.range = (start, stop, step) => { - const range = []; - start = interval.ceil(start); - step = step == null ? 1 : Math.floor(step); - if (!(start < stop) || !(step > 0)) return range; // also handles Invalid Date - let previous; - do range.push(previous = new Date(+start)), offseti(start, step), floori(start); - while (previous < start && start < stop); - return range; - }; - - interval.filter = (test) => { - return timeInterval((date) => { - if (date >= date) while (floori(date), !test(date)) date.setTime(date - 1); - }, (date, step) => { - if (date >= date) { - if (step < 0) while (++step <= 0) { - while (offseti(date, -1), !test(date)) {} // eslint-disable-line no-empty - } else while (--step >= 0) { - while (offseti(date, +1), !test(date)) {} // eslint-disable-line no-empty - } - } - }); - }; - - if (count) { - interval.count = (start, end) => { - t0.setTime(+start), t1.setTime(+end); - floori(t0), floori(t1); - return Math.floor(count(t0, t1)); - }; - - interval.every = (step) => { - step = Math.floor(step); - return !isFinite(step) || !(step > 0) ? null - : !(step > 1) ? interval - : interval.filter(field - ? (d) => field(d) % step === 0 - : (d) => interval.count(0, d) % step === 0); - }; - } - - return interval; - } - - const millisecond = timeInterval(() => { - // noop - }, (date, step) => { - date.setTime(+date + step); - }, (start, end) => { - return end - start; - }); - - // An optimized implementation for this simple case. - millisecond.every = (k) => { - k = Math.floor(k); - if (!isFinite(k) || !(k > 0)) return null; - if (!(k > 1)) return millisecond; - return timeInterval((date) => { - date.setTime(Math.floor(date / k) * k); - }, (date, step) => { - date.setTime(+date + step * k); - }, (start, end) => { - return (end - start) / k; - }); - }; - - millisecond.range; - - const durationSecond = 1000; - const durationMinute = durationSecond * 60; - const durationHour = durationMinute * 60; - const durationDay = durationHour * 24; - const durationWeek = durationDay * 7; - - const second = timeInterval((date) => { - date.setTime(date - date.getMilliseconds()); - }, (date, step) => { - date.setTime(+date + step * durationSecond); - }, (start, end) => { - return (end - start) / durationSecond; - }, (date) => { - return date.getUTCSeconds(); - }); - - second.range; - - const timeMinute = timeInterval((date) => { - date.setTime(date - date.getMilliseconds() - date.getSeconds() * durationSecond); - }, (date, step) => { - date.setTime(+date + step * durationMinute); - }, (start, end) => { - return (end - start) / durationMinute; - }, (date) => { - return date.getMinutes(); - }); - - timeMinute.range; - - const utcMinute = timeInterval((date) => { - date.setUTCSeconds(0, 0); - }, (date, step) => { - date.setTime(+date + step * durationMinute); - }, (start, end) => { - return (end - start) / durationMinute; - }, (date) => { - return date.getUTCMinutes(); - }); - - utcMinute.range; - - const timeHour = timeInterval((date) => { - date.setTime(date - date.getMilliseconds() - date.getSeconds() * durationSecond - date.getMinutes() * durationMinute); - }, (date, step) => { - date.setTime(+date + step * durationHour); - }, (start, end) => { - return (end - start) / durationHour; - }, (date) => { - return date.getHours(); - }); - - timeHour.range; - - const utcHour = timeInterval((date) => { - date.setUTCMinutes(0, 0, 0); - }, (date, step) => { - date.setTime(+date + step * durationHour); - }, (start, end) => { - return (end - start) / durationHour; - }, (date) => { - return date.getUTCHours(); - }); - - utcHour.range; - - const timeDay = timeInterval( - date => date.setHours(0, 0, 0, 0), - (date, step) => date.setDate(date.getDate() + step), - (start, end) => (end - start - (end.getTimezoneOffset() - start.getTimezoneOffset()) * durationMinute) / durationDay, - date => date.getDate() - 1 - ); - - timeDay.range; - - const utcDay = timeInterval((date) => { - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step); - }, (start, end) => { - return (end - start) / durationDay; - }, (date) => { - return date.getUTCDate() - 1; - }); - - utcDay.range; - - const unixDay = timeInterval((date) => { - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step); - }, (start, end) => { - return (end - start) / durationDay; - }, (date) => { - return Math.floor(date / durationDay); - }); - - unixDay.range; - - function timeWeekday(i) { - return timeInterval((date) => { - date.setDate(date.getDate() - (date.getDay() + 7 - i) % 7); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setDate(date.getDate() + step * 7); - }, (start, end) => { - return (end - start - (end.getTimezoneOffset() - start.getTimezoneOffset()) * durationMinute) / durationWeek; - }); - } - - const timeSunday = timeWeekday(0); - const timeMonday = timeWeekday(1); - const timeTuesday = timeWeekday(2); - const timeWednesday = timeWeekday(3); - const timeThursday = timeWeekday(4); - const timeFriday = timeWeekday(5); - const timeSaturday = timeWeekday(6); - - timeSunday.range; - timeMonday.range; - timeTuesday.range; - timeWednesday.range; - timeThursday.range; - timeFriday.range; - timeSaturday.range; - - function utcWeekday(i) { - return timeInterval((date) => { - date.setUTCDate(date.getUTCDate() - (date.getUTCDay() + 7 - i) % 7); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCDate(date.getUTCDate() + step * 7); - }, (start, end) => { - return (end - start) / durationWeek; - }); - } - - const utcSunday = utcWeekday(0); - const utcMonday = utcWeekday(1); - const utcTuesday = utcWeekday(2); - const utcWednesday = utcWeekday(3); - const utcThursday = utcWeekday(4); - const utcFriday = utcWeekday(5); - const utcSaturday = utcWeekday(6); - - utcSunday.range; - utcMonday.range; - utcTuesday.range; - utcWednesday.range; - utcThursday.range; - utcFriday.range; - utcSaturday.range; - - const timeMonth = timeInterval((date) => { - date.setDate(1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setMonth(date.getMonth() + step); - }, (start, end) => { - return end.getMonth() - start.getMonth() + (end.getFullYear() - start.getFullYear()) * 12; - }, (date) => { - return date.getMonth(); - }); - - timeMonth.range; - - const utcMonth = timeInterval((date) => { - date.setUTCDate(1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCMonth(date.getUTCMonth() + step); - }, (start, end) => { - return end.getUTCMonth() - start.getUTCMonth() + (end.getUTCFullYear() - start.getUTCFullYear()) * 12; - }, (date) => { - return date.getUTCMonth(); - }); - - utcMonth.range; - - const timeYear = timeInterval((date) => { - date.setMonth(0, 1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setFullYear(date.getFullYear() + step); - }, (start, end) => { - return end.getFullYear() - start.getFullYear(); - }, (date) => { - return date.getFullYear(); - }); - - // An optimized implementation for this simple case. - timeYear.every = (k) => { - return !isFinite(k = Math.floor(k)) || !(k > 0) ? null : timeInterval((date) => { - date.setFullYear(Math.floor(date.getFullYear() / k) * k); - date.setMonth(0, 1); - date.setHours(0, 0, 0, 0); - }, (date, step) => { - date.setFullYear(date.getFullYear() + step * k); - }); - }; - - timeYear.range; - - const utcYear = timeInterval((date) => { - date.setUTCMonth(0, 1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCFullYear(date.getUTCFullYear() + step); - }, (start, end) => { - return end.getUTCFullYear() - start.getUTCFullYear(); - }, (date) => { - return date.getUTCFullYear(); - }); - - // An optimized implementation for this simple case. - utcYear.every = (k) => { - return !isFinite(k = Math.floor(k)) || !(k > 0) ? null : timeInterval((date) => { - date.setUTCFullYear(Math.floor(date.getUTCFullYear() / k) * k); - date.setUTCMonth(0, 1); - date.setUTCHours(0, 0, 0, 0); - }, (date, step) => { - date.setUTCFullYear(date.getUTCFullYear() + step * k); - }); - }; - - utcYear.range; - - function localDate(d) { - if (0 <= d.y && d.y < 100) { - var date = new Date(-1, d.m, d.d, d.H, d.M, d.S, d.L); - date.setFullYear(d.y); - return date; - } - return new Date(d.y, d.m, d.d, d.H, d.M, d.S, d.L); - } - - function utcDate(d) { - if (0 <= d.y && d.y < 100) { - var date = new Date(Date.UTC(-1, d.m, d.d, d.H, d.M, d.S, d.L)); - date.setUTCFullYear(d.y); - return date; - } - return new Date(Date.UTC(d.y, d.m, d.d, d.H, d.M, d.S, d.L)); - } - - function newDate(y, m, d) { - return {y: y, m: m, d: d, H: 0, M: 0, S: 0, L: 0}; - } - - function formatLocale(locale) { - var locale_dateTime = locale.dateTime, - locale_date = locale.date, - locale_time = locale.time, - locale_periods = locale.periods, - locale_weekdays = locale.days, - locale_shortWeekdays = locale.shortDays, - locale_months = locale.months, - locale_shortMonths = locale.shortMonths; - - var periodRe = formatRe(locale_periods), - periodLookup = formatLookup(locale_periods), - weekdayRe = formatRe(locale_weekdays), - weekdayLookup = formatLookup(locale_weekdays), - shortWeekdayRe = formatRe(locale_shortWeekdays), - shortWeekdayLookup = formatLookup(locale_shortWeekdays), - monthRe = formatRe(locale_months), - monthLookup = formatLookup(locale_months), - shortMonthRe = formatRe(locale_shortMonths), - shortMonthLookup = formatLookup(locale_shortMonths); - - var formats = { - "a": formatShortWeekday, - "A": formatWeekday, - "b": formatShortMonth, - "B": formatMonth, - "c": null, - "d": formatDayOfMonth, - "e": formatDayOfMonth, - "f": formatMicroseconds, - "g": formatYearISO, - "G": formatFullYearISO, - "H": formatHour24, - "I": formatHour12, - "j": formatDayOfYear, - "L": formatMilliseconds, - "m": formatMonthNumber, - "M": formatMinutes, - "p": formatPeriod, - "q": formatQuarter, - "Q": formatUnixTimestamp, - "s": formatUnixTimestampSeconds, - "S": formatSeconds, - "u": formatWeekdayNumberMonday, - "U": formatWeekNumberSunday, - "V": formatWeekNumberISO, - "w": formatWeekdayNumberSunday, - "W": formatWeekNumberMonday, - "x": null, - "X": null, - "y": formatYear, - "Y": formatFullYear, - "Z": formatZone, - "%": formatLiteralPercent - }; - - var utcFormats = { - "a": formatUTCShortWeekday, - "A": formatUTCWeekday, - "b": formatUTCShortMonth, - "B": formatUTCMonth, - "c": null, - "d": formatUTCDayOfMonth, - "e": formatUTCDayOfMonth, - "f": formatUTCMicroseconds, - "g": formatUTCYearISO, - "G": formatUTCFullYearISO, - "H": formatUTCHour24, - "I": formatUTCHour12, - "j": formatUTCDayOfYear, - "L": formatUTCMilliseconds, - "m": formatUTCMonthNumber, - "M": formatUTCMinutes, - "p": formatUTCPeriod, - "q": formatUTCQuarter, - "Q": formatUnixTimestamp, - "s": formatUnixTimestampSeconds, - "S": formatUTCSeconds, - "u": formatUTCWeekdayNumberMonday, - "U": formatUTCWeekNumberSunday, - "V": formatUTCWeekNumberISO, - "w": formatUTCWeekdayNumberSunday, - "W": formatUTCWeekNumberMonday, - "x": null, - "X": null, - "y": formatUTCYear, - "Y": formatUTCFullYear, - "Z": formatUTCZone, - "%": formatLiteralPercent - }; - - var parses = { - "a": parseShortWeekday, - "A": parseWeekday, - "b": parseShortMonth, - "B": parseMonth, - "c": parseLocaleDateTime, - "d": parseDayOfMonth, - "e": parseDayOfMonth, - "f": parseMicroseconds, - "g": parseYear, - "G": parseFullYear, - "H": parseHour24, - "I": parseHour24, - "j": parseDayOfYear, - "L": parseMilliseconds, - "m": parseMonthNumber, - "M": parseMinutes, - "p": parsePeriod, - "q": parseQuarter, - "Q": parseUnixTimestamp, - "s": parseUnixTimestampSeconds, - "S": parseSeconds, - "u": parseWeekdayNumberMonday, - "U": parseWeekNumberSunday, - "V": parseWeekNumberISO, - "w": parseWeekdayNumberSunday, - "W": parseWeekNumberMonday, - "x": parseLocaleDate, - "X": parseLocaleTime, - "y": parseYear, - "Y": parseFullYear, - "Z": parseZone, - "%": parseLiteralPercent - }; - - // These recursive directive definitions must be deferred. - formats.x = newFormat(locale_date, formats); - formats.X = newFormat(locale_time, formats); - formats.c = newFormat(locale_dateTime, formats); - utcFormats.x = newFormat(locale_date, utcFormats); - utcFormats.X = newFormat(locale_time, utcFormats); - utcFormats.c = newFormat(locale_dateTime, utcFormats); - - function newFormat(specifier, formats) { - return function(date) { - var string = [], - i = -1, - j = 0, - n = specifier.length, - c, - pad, - format; - - if (!(date instanceof Date)) date = new Date(+date); - - while (++i < n) { - if (specifier.charCodeAt(i) === 37) { - string.push(specifier.slice(j, i)); - if ((pad = pads[c = specifier.charAt(++i)]) != null) c = specifier.charAt(++i); - else pad = c === "e" ? " " : "0"; - if (format = formats[c]) c = format(date, pad); - string.push(c); - j = i + 1; - } - } - - string.push(specifier.slice(j, i)); - return string.join(""); - }; - } - - function newParse(specifier, Z) { - return function(string) { - var d = newDate(1900, undefined, 1), - i = parseSpecifier(d, specifier, string += "", 0), - week, day; - if (i != string.length) return null; - - // If a UNIX timestamp is specified, return it. - if ("Q" in d) return new Date(d.Q); - if ("s" in d) return new Date(d.s * 1000 + ("L" in d ? d.L : 0)); - - // If this is utcParse, never use the local timezone. - if (Z && !("Z" in d)) d.Z = 0; - - // The am-pm flag is 0 for AM, and 1 for PM. - if ("p" in d) d.H = d.H % 12 + d.p * 12; - - // If the month was not specified, inherit from the quarter. - if (d.m === undefined) d.m = "q" in d ? d.q : 0; - - // Convert day-of-week and week-of-year to day-of-year. - if ("V" in d) { - if (d.V < 1 || d.V > 53) return null; - if (!("w" in d)) d.w = 1; - if ("Z" in d) { - week = utcDate(newDate(d.y, 0, 1)), day = week.getUTCDay(); - week = day > 4 || day === 0 ? utcMonday.ceil(week) : utcMonday(week); - week = utcDay.offset(week, (d.V - 1) * 7); - d.y = week.getUTCFullYear(); - d.m = week.getUTCMonth(); - d.d = week.getUTCDate() + (d.w + 6) % 7; - } else { - week = localDate(newDate(d.y, 0, 1)), day = week.getDay(); - week = day > 4 || day === 0 ? timeMonday.ceil(week) : timeMonday(week); - week = timeDay.offset(week, (d.V - 1) * 7); - d.y = week.getFullYear(); - d.m = week.getMonth(); - d.d = week.getDate() + (d.w + 6) % 7; - } - } else if ("W" in d || "U" in d) { - if (!("w" in d)) d.w = "u" in d ? d.u % 7 : "W" in d ? 1 : 0; - day = "Z" in d ? utcDate(newDate(d.y, 0, 1)).getUTCDay() : localDate(newDate(d.y, 0, 1)).getDay(); - d.m = 0; - d.d = "W" in d ? (d.w + 6) % 7 + d.W * 7 - (day + 5) % 7 : d.w + d.U * 7 - (day + 6) % 7; - } - - // If a time zone is specified, all fields are interpreted as UTC and then - // offset according to the specified time zone. - if ("Z" in d) { - d.H += d.Z / 100 | 0; - d.M += d.Z % 100; - return utcDate(d); - } - - // Otherwise, all fields are in local time. - return localDate(d); - }; - } - - function parseSpecifier(d, specifier, string, j) { - var i = 0, - n = specifier.length, - m = string.length, - c, - parse; - - while (i < n) { - if (j >= m) return -1; - c = specifier.charCodeAt(i++); - if (c === 37) { - c = specifier.charAt(i++); - parse = parses[c in pads ? specifier.charAt(i++) : c]; - if (!parse || ((j = parse(d, string, j)) < 0)) return -1; - } else if (c != string.charCodeAt(j++)) { - return -1; - } - } - - return j; - } - - function parsePeriod(d, string, i) { - var n = periodRe.exec(string.slice(i)); - return n ? (d.p = periodLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseShortWeekday(d, string, i) { - var n = shortWeekdayRe.exec(string.slice(i)); - return n ? (d.w = shortWeekdayLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseWeekday(d, string, i) { - var n = weekdayRe.exec(string.slice(i)); - return n ? (d.w = weekdayLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseShortMonth(d, string, i) { - var n = shortMonthRe.exec(string.slice(i)); - return n ? (d.m = shortMonthLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseMonth(d, string, i) { - var n = monthRe.exec(string.slice(i)); - return n ? (d.m = monthLookup.get(n[0].toLowerCase()), i + n[0].length) : -1; - } - - function parseLocaleDateTime(d, string, i) { - return parseSpecifier(d, locale_dateTime, string, i); - } - - function parseLocaleDate(d, string, i) { - return parseSpecifier(d, locale_date, string, i); - } - - function parseLocaleTime(d, string, i) { - return parseSpecifier(d, locale_time, string, i); - } - - function formatShortWeekday(d) { - return locale_shortWeekdays[d.getDay()]; - } - - function formatWeekday(d) { - return locale_weekdays[d.getDay()]; - } - - function formatShortMonth(d) { - return locale_shortMonths[d.getMonth()]; - } - - function formatMonth(d) { - return locale_months[d.getMonth()]; - } - - function formatPeriod(d) { - return locale_periods[+(d.getHours() >= 12)]; - } - - function formatQuarter(d) { - return 1 + ~~(d.getMonth() / 3); - } - - function formatUTCShortWeekday(d) { - return locale_shortWeekdays[d.getUTCDay()]; - } - - function formatUTCWeekday(d) { - return locale_weekdays[d.getUTCDay()]; - } - - function formatUTCShortMonth(d) { - return locale_shortMonths[d.getUTCMonth()]; - } - - function formatUTCMonth(d) { - return locale_months[d.getUTCMonth()]; - } - - function formatUTCPeriod(d) { - return locale_periods[+(d.getUTCHours() >= 12)]; - } - - function formatUTCQuarter(d) { - return 1 + ~~(d.getUTCMonth() / 3); - } - - return { - format: function(specifier) { - var f = newFormat(specifier += "", formats); - f.toString = function() { return specifier; }; - return f; - }, - parse: function(specifier) { - var p = newParse(specifier += "", false); - p.toString = function() { return specifier; }; - return p; - }, - utcFormat: function(specifier) { - var f = newFormat(specifier += "", utcFormats); - f.toString = function() { return specifier; }; - return f; - }, - utcParse: function(specifier) { - var p = newParse(specifier += "", true); - p.toString = function() { return specifier; }; - return p; - } - }; - } - - var pads = {"-": "", "_": " ", "0": "0"}, - numberRe = /^\s*\d+/, // note: ignores next directive - percentRe = /^%/, - requoteRe = /[\\^$*+?|[\]().{}]/g; - - function pad(value, fill, width) { - var sign = value < 0 ? "-" : "", - string = (sign ? -value : value) + "", - length = string.length; - return sign + (length < width ? new Array(width - length + 1).join(fill) + string : string); - } - - function requote(s) { - return s.replace(requoteRe, "\\$&"); - } - - function formatRe(names) { - return new RegExp("^(?:" + names.map(requote).join("|") + ")", "i"); - } - - function formatLookup(names) { - return new Map(names.map((name, i) => [name.toLowerCase(), i])); - } - - function parseWeekdayNumberSunday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.w = +n[0], i + n[0].length) : -1; - } - - function parseWeekdayNumberMonday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.u = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberSunday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.U = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberISO(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.V = +n[0], i + n[0].length) : -1; - } - - function parseWeekNumberMonday(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.W = +n[0], i + n[0].length) : -1; - } - - function parseFullYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 4)); - return n ? (d.y = +n[0], i + n[0].length) : -1; - } - - function parseYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.y = +n[0] + (+n[0] > 68 ? 1900 : 2000), i + n[0].length) : -1; - } - - function parseZone(d, string, i) { - var n = /^(Z)|([+-]\d\d)(?::?(\d\d))?/.exec(string.slice(i, i + 6)); - return n ? (d.Z = n[1] ? 0 : -(n[2] + (n[3] || "00")), i + n[0].length) : -1; - } - - function parseQuarter(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 1)); - return n ? (d.q = n[0] * 3 - 3, i + n[0].length) : -1; - } - - function parseMonthNumber(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.m = n[0] - 1, i + n[0].length) : -1; - } - - function parseDayOfMonth(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.d = +n[0], i + n[0].length) : -1; - } - - function parseDayOfYear(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 3)); - return n ? (d.m = 0, d.d = +n[0], i + n[0].length) : -1; - } - - function parseHour24(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.H = +n[0], i + n[0].length) : -1; - } - - function parseMinutes(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.M = +n[0], i + n[0].length) : -1; - } - - function parseSeconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 2)); - return n ? (d.S = +n[0], i + n[0].length) : -1; - } - - function parseMilliseconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 3)); - return n ? (d.L = +n[0], i + n[0].length) : -1; - } - - function parseMicroseconds(d, string, i) { - var n = numberRe.exec(string.slice(i, i + 6)); - return n ? (d.L = Math.floor(n[0] / 1000), i + n[0].length) : -1; - } - - function parseLiteralPercent(d, string, i) { - var n = percentRe.exec(string.slice(i, i + 1)); - return n ? i + n[0].length : -1; - } - - function parseUnixTimestamp(d, string, i) { - var n = numberRe.exec(string.slice(i)); - return n ? (d.Q = +n[0], i + n[0].length) : -1; - } - - function parseUnixTimestampSeconds(d, string, i) { - var n = numberRe.exec(string.slice(i)); - return n ? (d.s = +n[0], i + n[0].length) : -1; - } - - function formatDayOfMonth(d, p) { - return pad(d.getDate(), p, 2); - } - - function formatHour24(d, p) { - return pad(d.getHours(), p, 2); - } - - function formatHour12(d, p) { - return pad(d.getHours() % 12 || 12, p, 2); - } - - function formatDayOfYear(d, p) { - return pad(1 + timeDay.count(timeYear(d), d), p, 3); - } - - function formatMilliseconds(d, p) { - return pad(d.getMilliseconds(), p, 3); - } - - function formatMicroseconds(d, p) { - return formatMilliseconds(d, p) + "000"; - } - - function formatMonthNumber(d, p) { - return pad(d.getMonth() + 1, p, 2); - } - - function formatMinutes(d, p) { - return pad(d.getMinutes(), p, 2); - } - - function formatSeconds(d, p) { - return pad(d.getSeconds(), p, 2); - } - - function formatWeekdayNumberMonday(d) { - var day = d.getDay(); - return day === 0 ? 7 : day; - } - - function formatWeekNumberSunday(d, p) { - return pad(timeSunday.count(timeYear(d) - 1, d), p, 2); - } - - function dISO(d) { - var day = d.getDay(); - return (day >= 4 || day === 0) ? timeThursday(d) : timeThursday.ceil(d); - } - - function formatWeekNumberISO(d, p) { - d = dISO(d); - return pad(timeThursday.count(timeYear(d), d) + (timeYear(d).getDay() === 4), p, 2); - } - - function formatWeekdayNumberSunday(d) { - return d.getDay(); - } - - function formatWeekNumberMonday(d, p) { - return pad(timeMonday.count(timeYear(d) - 1, d), p, 2); - } - - function formatYear(d, p) { - return pad(d.getFullYear() % 100, p, 2); - } - - function formatYearISO(d, p) { - d = dISO(d); - return pad(d.getFullYear() % 100, p, 2); - } - - function formatFullYear(d, p) { - return pad(d.getFullYear() % 10000, p, 4); - } - - function formatFullYearISO(d, p) { - var day = d.getDay(); - d = (day >= 4 || day === 0) ? timeThursday(d) : timeThursday.ceil(d); - return pad(d.getFullYear() % 10000, p, 4); - } - - function formatZone(d) { - var z = d.getTimezoneOffset(); - return (z > 0 ? "-" : (z *= -1, "+")) - + pad(z / 60 | 0, "0", 2) - + pad(z % 60, "0", 2); - } - - function formatUTCDayOfMonth(d, p) { - return pad(d.getUTCDate(), p, 2); - } - - function formatUTCHour24(d, p) { - return pad(d.getUTCHours(), p, 2); - } - - function formatUTCHour12(d, p) { - return pad(d.getUTCHours() % 12 || 12, p, 2); - } - - function formatUTCDayOfYear(d, p) { - return pad(1 + utcDay.count(utcYear(d), d), p, 3); - } - - function formatUTCMilliseconds(d, p) { - return pad(d.getUTCMilliseconds(), p, 3); - } - - function formatUTCMicroseconds(d, p) { - return formatUTCMilliseconds(d, p) + "000"; - } - - function formatUTCMonthNumber(d, p) { - return pad(d.getUTCMonth() + 1, p, 2); - } - - function formatUTCMinutes(d, p) { - return pad(d.getUTCMinutes(), p, 2); - } - - function formatUTCSeconds(d, p) { - return pad(d.getUTCSeconds(), p, 2); - } - - function formatUTCWeekdayNumberMonday(d) { - var dow = d.getUTCDay(); - return dow === 0 ? 7 : dow; - } - - function formatUTCWeekNumberSunday(d, p) { - return pad(utcSunday.count(utcYear(d) - 1, d), p, 2); - } - - function UTCdISO(d) { - var day = d.getUTCDay(); - return (day >= 4 || day === 0) ? utcThursday(d) : utcThursday.ceil(d); - } - - function formatUTCWeekNumberISO(d, p) { - d = UTCdISO(d); - return pad(utcThursday.count(utcYear(d), d) + (utcYear(d).getUTCDay() === 4), p, 2); - } - - function formatUTCWeekdayNumberSunday(d) { - return d.getUTCDay(); - } - - function formatUTCWeekNumberMonday(d, p) { - return pad(utcMonday.count(utcYear(d) - 1, d), p, 2); - } - - function formatUTCYear(d, p) { - return pad(d.getUTCFullYear() % 100, p, 2); - } - - function formatUTCYearISO(d, p) { - d = UTCdISO(d); - return pad(d.getUTCFullYear() % 100, p, 2); - } - - function formatUTCFullYear(d, p) { - return pad(d.getUTCFullYear() % 10000, p, 4); - } - - function formatUTCFullYearISO(d, p) { - var day = d.getUTCDay(); - d = (day >= 4 || day === 0) ? utcThursday(d) : utcThursday.ceil(d); - return pad(d.getUTCFullYear() % 10000, p, 4); - } - - function formatUTCZone() { - return "+0000"; - } - - function formatLiteralPercent() { - return "%"; - } - - function formatUnixTimestamp(d) { - return +d; - } - - function formatUnixTimestampSeconds(d) { - return Math.floor(+d / 1000); - } - - var locale; - var utcFormat; - var utcParse; - - defaultLocale({ - dateTime: "%x, %X", - date: "%-m/%-d/%Y", - time: "%-I:%M:%S %p", - periods: ["AM", "PM"], - days: ["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"], - shortDays: ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"], - months: ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"], - shortMonths: ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] - }); - - function defaultLocale(definition) { - locale = formatLocale(definition); - locale.format; - locale.parse; - utcFormat = locale.utcFormat; - utcParse = locale.utcParse; - return locale; - } - - var isoSpecifier = "%Y-%m-%dT%H:%M:%S.%LZ"; - - function formatIsoNative(date) { - return date.toISOString(); - } - - Date.prototype.toISOString - ? formatIsoNative - : utcFormat(isoSpecifier); - - function parseIsoNative(string) { - var date = new Date(string); - return isNaN(date) ? null : date; - } - - +new Date("2000-01-01T00:00:00.000Z") - ? parseIsoNative - : utcParse(isoSpecifier); - function transformer() { var x0 = 0, x1 = 1, @@ -6087,7 +4748,7 @@ var drawChart = (function (exports) { }; }; - const StaticContext = B$2({}); + const StaticContext = G({}); const drawChart = (parentNode, data, width, height) => { const availableSizeProperties = getAvailableSizeOptions(data.options); console.time("layout create"); @@ -6131,7 +4792,7 @@ var drawChart = (function (exports) { console.time("color"); const getModuleColor = createRainbowColor(rawHierarchy); console.timeEnd("color"); - P(o$1(StaticContext.Provider, Object.assign({ value: { + D(o$1(StaticContext.Provider, { value: { data, availableSizeProperties, width, @@ -6141,7 +4802,7 @@ var drawChart = (function (exports) { getModuleColor, rawHierarchy, layout, - } }, { children: o$1(Main, {}) })), parentNode); + }, children: o$1(Main, {}) }), parentNode); }; exports.StaticContext = StaticContext; @@ -6157,7 +4818,7 @@ var drawChart = (function (exports) {