Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit a958ae5e29: Merge branch 'main' into feat/use-custom-vaes
@@ -17,7 +17,7 @@ notebooks.

You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
-library](../installation/070_INSTALL_XFORMERS) to accelerate the
+library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.
@@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:

brew install opencv
```

-The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
+The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.

## Linux
@@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:

5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type
-`from patchmatch import patch_match`: It should look like the follwing:
+`from patchmatch import patch_match`: It should look like the following:

```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:

[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)

-If you see no errors, then you're ready to go!
+If you see no errors you're ready to go!
@@ -10,6 +10,7 @@ from pydantic.fields import Field

from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
+    Edge,
    EdgeConnection,
    Graph,
    GraphExecutionState,

@@ -92,7 +93,7 @@ async def get_session(
async def add_node(
    session_id: str = Path(description="The id of the session"),
    node: Annotated[
-        Union[BaseInvocation.get_invocations()], Field(discriminator="type")
+        Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
    ] = Body(description="The node to add"),
) -> str:
    """Adds a node to the graph"""
@@ -125,7 +126,7 @@ async def update_node(
    session_id: str = Path(description="The id of the session"),
    node_path: str = Path(description="The path to the node in the graph"),
    node: Annotated[
-        Union[BaseInvocation.get_invocations()], Field(discriminator="type")
+        Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
    ] = Body(description="The new node"),
) -> GraphExecutionState:
    """Updates a node in the graph and removes all linked edges"""

@@ -186,7 +187,7 @@ async def delete_node(
)
async def add_edge(
    session_id: str = Path(description="The id of the session"),
-    edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
+    edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
    """Adds an edge to the graph"""
    session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)

@@ -228,9 +229,9 @@ async def delete_edge(
        return Response(status_code=404)

    try:
-        edge = (
-            EdgeConnection(node_id=from_node_id, field=from_field),
-            EdgeConnection(node_id=to_node_id, field=to_field),
+        edge = Edge(
+            source=EdgeConnection(node_id=from_node_id, field=from_field),
+            destination=EdgeConnection(node_id=to_node_id, field=to_field)
        )
        session.delete_edge(edge)
        ApiDependencies.invoker.services.graph_execution_manager.set(
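Several endpoints above switch from a bare `(source, destination)` tuple to the new `Edge` model. A standalone sketch of the difference, re-declaring the two models locally (mirroring their definitions later in this diff) so it runs on its own; the node ids are made-up examples:

```python
from pydantic import BaseModel, Field

class EdgeConnection(BaseModel):
    node_id: str
    field: str

class Edge(BaseModel):
    source: EdgeConnection = Field(description="The connection for the edge's from node and field")
    destination: EdgeConnection = Field(description="The connection for the edge's to node and field")

# Old shape: position carries the meaning
old_edge = (
    EdgeConnection(node_id="txt2img_1", field="image"),
    EdgeConnection(node_id="upscale_1", field="image"),
)

# New shape: named fields, and a schema FastAPI can describe in OpenAPI
new_edge = Edge(
    source=EdgeConnection(node_id="txt2img_1", field="image"),
    destination=EdgeConnection(node_id="upscale_1", field="image"),
)
print(new_edge.dict())
```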
@@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
-from .services.graph import EdgeConnection, GraphExecutionState
+from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices

@@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:

def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation
-) -> list[tuple[EdgeConnection, EdgeConnection]]:
+) -> list[Edge]:
    """Generates all possible edges between two invocations"""
    atype = type(a)
    btype = type(b)

@@ -94,9 +94,9 @@ def generate_matching_edges(
    matching_fields = matching_fields.difference(invalid_fields)

    edges = [
-        (
-            EdgeConnection(node_id=a.id, field=field),
-            EdgeConnection(node_id=b.id, field=field),
+        Edge(
+            source=EdgeConnection(node_id=a.id, field=field),
+            destination=EdgeConnection(node_id=b.id, field=field)
        )
        for field in matching_fields
    ]

@@ -111,16 +111,15 @@ class SessionError(Exception):
def invoke_all(context: CliContext):
    """Runs all invocations in the specified session"""
    context.invoker.invoke(context.session, invoke_all=True)
-    while not context.session.is_complete():
+    while not context.get_session().is_complete():
        # Wait some time
-        session = context.get_session()
        time.sleep(0.1)

    # Print any errors
    if context.session.has_error():
        for n in context.session.errors:
            print(
-                f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
+                f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
            )

        raise SessionError()
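The `invoke_all` fix has two parts: the completion check now re-reads the session on every pass, and the error message uses `context.session.errors` instead of a stale local. A hedged sketch of that polling pattern; `get_session` and the fake session class are stand-ins for the CLI context, not real InvokeAI APIs:

```python
import time

class SessionError(Exception):
    """Stand-in for the CLI's SessionError."""

def wait_for_completion(get_session, poll_interval: float = 0.1):
    # Re-fetch the session every iteration so is_complete()/errors reflect fresh state.
    while not get_session().is_complete():
        time.sleep(poll_interval)

    session = get_session()
    if session.has_error():
        for node_id, message in session.errors.items():
            print(f"Error in node {node_id}: {message}")
        raise SessionError()
    return session

class _FakeSession:
    """Tiny stand-in so the sketch runs on its own."""
    def __init__(self): self._polls = 0
    def is_complete(self): self._polls += 1; return self._polls > 2
    def has_error(self): return False
    @property
    def errors(self): return {}

_s = _FakeSession()
print(wait_for_completion(lambda: _s) is _s)  # True
```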
@@ -203,7 +202,7 @@ def invoke_cli():
            continue

            # Pipe previous command output (if there was a previous command)
-            edges = []
+            edges: list[Edge] = list()
            if len(history) > 0 or current_id != start_id:
                from_id = (
                    history[0] if current_id == start_id else str(current_id - 1)

@@ -225,19 +224,19 @@ def invoke_cli():
                matching_edges = generate_matching_edges(
                    link_node, command.command
                )
-                matching_destinations = [e[1] for e in matching_edges]
-                edges = [e for e in edges if e[1] not in matching_destinations]
+                matching_destinations = [e.destination for e in matching_edges]
+                edges = [e for e in edges if e.destination not in matching_destinations]
                edges.extend(matching_edges)

            if "link" in args and args["link"]:
                for link in args["link"]:
-                    edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
+                    edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
                    edges.append(
-                        (
-                            EdgeConnection(node_id=link[1], field=link[0]),
-                            EdgeConnection(
+                        Edge(
+                            source=EdgeConnection(node_id=link[1], field=link[0]),
+                            destination=EdgeConnection(
                                node_id=command.command.id, field=link[2]
                            ),
                        )
                    )
                )
@@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union

import numpy as np

from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms

@@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.util.util import image_to_dataURL

SAMPLER_NAME_VALUES = Literal[
    tuple(InvokeAIGenerator.schedulers())

@@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
-        self, context: InvocationContext, sample: Any = None, step: int = 0
-    ) -> None:
+        self, context: InvocationContext, sample: Tensor, step: int
+    ) -> None:
        # TODO: only output a preview image when requested
        image = Generator.sample_to_lowres_estimated_image(sample)

        (width, height) = image.size
        width *= 8
        height *= 8

        dataURL = image_to_dataURL(image, image_format="JPEG")

        context.services.events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            {
                "width": width,
                "height": height,
                "dataURL": dataURL
            },
            step,
            float(step) / float(self.steps),
            self.steps,
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
+        def step_callback(state: PipelineIntermediateState):
+            self.dispatch_progress(context, state.latents, state.step)

        # Handle invalid model parameter
        # TODO: figure out if this can be done via a validator that uses the model_cache
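The callback now receives a `PipelineIntermediateState`-style object and unpacks `.latents`/`.step` itself, and the preview dimensions are multiplied by 8 because the estimate is produced at latent resolution. A self-contained sketch of that flow; `FakeState` and `fake_preview` are stand-ins, not InvokeAI APIs:

```python
from dataclasses import dataclass

from PIL import Image

@dataclass
class FakeState:          # stand-in for PipelineIntermediateState
    latents: object
    step: int

def fake_preview(_latents) -> Image.Image:
    # stands in for Generator.sample_to_lowres_estimated_image(), which returns
    # a PIL image at 1/8 of the final resolution
    return Image.new("RGB", (64, 64))

def dispatch_progress(latents, step: int) -> tuple[int, int, int]:
    preview = fake_preview(latents)
    width, height = preview.size
    return (width * 8, height * 8, step)   # report full-resolution dimensions

def step_callback(state: FakeState) -> None:
    print(dispatch_progress(state.latents, state.step))  # (512, 512, 10)

step_callback(FakeState(latents=None, step=10))
```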
@@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Any, Dict
+from typing import Any, Dict, TypedDict

+ProgressImage = TypedDict(
+    "ProgressImage", {"dataURL": str, "width": int, "height": int}
+)

class EventServiceBase:
    session_event: str = "session_event"

@@ -23,8 +26,9 @@ class EventServiceBase:
        self,
        graph_execution_state_id: str,
        invocation_id: str,
+        progress_image: ProgressImage | None,
        step: int,
        percent: float,
        total_steps: int,
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_session_event(

@@ -32,8 +36,9 @@ class EventServiceBase:
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                invocation_id=invocation_id,
+                progress_image=progress_image,
                step=step,
                percent=percent,
                total_steps=total_steps,
            ),
        )
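`ProgressImage` gives the progress event a typed shape for its preview payload. A small sketch using the equivalent class-based `TypedDict` syntax; the data-URL value below is a dummy:

```python
from __future__ import annotations

from typing import TypedDict

class ProgressImage(TypedDict):
    dataURL: str
    width: int
    height: int

def progress_payload(progress_image: ProgressImage | None, step: int, total_steps: int) -> dict:
    # Same fields the event emitter sends; percent mirrors float(step) / float(total_steps).
    return dict(
        progress_image=progress_image,
        step=step,
        percent=float(step) / float(total_steps),
        total_steps=total_steps,
    )

print(progress_payload({"dataURL": "data:image/jpeg;base64,...", "width": 512, "height": 512}, 5, 50))
```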
@@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
        return hash(f"{self.node_id}.{self.field}")


+class Edge(BaseModel):
+    source: EdgeConnection = Field(description="The connection for the edge's from node and field")
+    destination: EdgeConnection = Field(description="The connection for the edge's to node and field")


def get_output_field(node: BaseInvocation, field: str) -> Any:
    node_type = type(node)
    node_outputs = get_type_hints(node_type.get_output_type())

@@ -194,7 +199,7 @@ class Graph(BaseModel):
    nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
        description="The nodes in this graph", default_factory=dict
    )
-    edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
+    edges: list[Edge] = Field(
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )
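`edges: list[Edge]` is what the cycle check in `_is_edge_valid` (further down in this diff) walks when deciding whether an edge may be added. A standalone sketch of that check with networkx, using plain node-id strings in place of full `Edge` objects:

```python
import networkx as nx

def would_create_cycle(existing: list[tuple[str, str]], proposed: tuple[str, str]) -> bool:
    # Mirror of the nx_graph_flat() + is_directed_acyclic_graph() test: add the
    # proposed (source, destination) pair and reject it if the graph stops being a DAG.
    g = nx.DiGraph()
    g.add_edges_from(existing)
    g.add_edge(*proposed)
    return not nx.is_directed_acyclic_graph(g)

print(would_create_cycle([("a", "b"), ("b", "c")], ("c", "a")))  # True
print(would_create_cycle([("a", "b"), ("b", "c")], ("a", "c")))  # False
```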
@@ -251,7 +256,7 @@ class Graph(BaseModel):
        except NodeNotFoundError:
            pass # Ignore, not doesn't exist (should this throw?)

-    def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+    def add_edge(self, edge: Edge) -> None:
        """Adds an edge to a graph

        :raises InvalidEdgeError: the provided edge is invalid.

@@ -262,7 +267,7 @@ class Graph(BaseModel):
        else:
            raise InvalidEdgeError()

-    def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
+    def delete_edge(self, edge: Edge) -> None:
        """Deletes an edge from a graph"""

        try:
@ -280,7 +285,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set(
|
||||
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
|
||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||
)
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
@ -294,10 +299,10 @@ class Graph(BaseModel):
|
||||
if not all(
|
||||
(
|
||||
are_connections_compatible(
|
||||
self.get_node(e[0].node_id),
|
||||
e[0].field,
|
||||
self.get_node(e[1].node_id),
|
||||
e[1].field,
|
||||
self.get_node(e.source.node_id),
|
||||
e.source.field,
|
||||
self.get_node(e.destination.node_id),
|
||||
e.destination.field,
|
||||
)
|
||||
for e in self.edges
|
||||
)
|
||||
@ -328,58 +333,58 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge[0].node_id)
|
||||
to_node = self.get_node(edge[1].node_id)
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge[0].node_id, edge[1].node_id)
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge[0].field, to_node, edge[1].field
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
@ -438,15 +443,15 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[1].node_id
|
||||
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
edge[0],
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[1].field
|
||||
),
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.destination.field
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@ -454,51 +459,51 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[0].node_id
|
||||
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[0].field
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.source.field
|
||||
),
|
||||
edge[1],
|
||||
destination=edge.destination
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(
|
||||
self, node_path: str, field: Optional[str] = None
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -522,37 +527,37 @@ class Graph(BaseModel):
|
||||
|
||||
def _get_output_edges(
|
||||
self, node_path: str, field: str
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -580,8 +585,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -622,8 +627,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -684,7 +689,7 @@ class Graph(BaseModel):
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
@ -711,7 +716,7 @@ class Graph(BaseModel):
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
g.add_edges_from(
|
||||
[
|
||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'id',
|
||||
'graph',
|
||||
'execution_graph',
|
||||
'executed',
|
||||
'executed_history',
|
||||
'results',
|
||||
'errors',
|
||||
'prepared_source_mapping',
|
||||
'source_prepared_mapping',
|
||||
]
|
||||
}
|
||||
|
||||
def next(self) -> BaseInvocation | None:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1]
|
||||
for n in iteration_node_map
|
||||
if n[0] == input_collection_edge[0].node_id
|
||||
if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[
|
||||
input_collection_prepared_node_id
|
||||
]
|
||||
input_collection = getattr(
|
||||
input_collection_prepared_node_output, input_collection_edge[0].field
|
||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||
)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (
|
||||
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
|
||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||
):
|
||||
new_edge = (
|
||||
EdgeConnection(node_id=input_node_id, field=edge[0].field),
|
||||
EdgeConnection(node_id="", field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||
)
|
||||
new_edges.append(new_edge)
|
||||
|
||||
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
new_edge = (
|
||||
edge[0],
|
||||
EdgeConnection(node_id=new_node.id, field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
)
|
||||
self.execution_graph.add_edge(new_edge)
|
||||
|
||||
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
|
||||
return self.execution_graph.nodes[next_node]
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
|
||||
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||
if isinstance(node, CollectInvocation):
|
||||
output_collection = [
|
||||
getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
for edge in input_edges
|
||||
if edge[1].field == "item"
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
setattr(node, "collection", output_collection)
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
setattr(node, edge[1].field, output_value)
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
setattr(node, edge.destination.field, output_value)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
if edge[1].node_id in self.source_prepared_mapping:
|
||||
if edge.destination.node_id in self.source_prepared_mapping:
|
||||
return False
|
||||
|
||||
# Otherwise, the edge is valid
|
||||
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
|
||||
)
|
||||
self.graph.add_edge(edge)
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
)
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
@@ -490,7 +490,7 @@ class Args(object):
            "-z",
            type=int,
            default=6,
-            choices=range(0, 9),
+            choices=range(0, 10),
            dest="png_compression",
            help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
        )
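The off-by-one is easy to see in a REPL: `range()` excludes its stop value, so the old `choices` rejected the documented maximum of 9.

```python
# range() excludes the stop value, so range(0, 9) silently rejects -z 9.
print(list(range(0, 9)))   # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(9 in range(0, 10))   # True: the corrected choices accept the maximum level
```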
@@ -943,7 +943,6 @@ class Args(object):
            "--png_compression",
            "-z",
            type=int,
            default=6,
            choices=range(0, 10),
            dest="png_compression",
            help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",
@@ -497,7 +497,8 @@ class Generator:
        matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
        return matched_result

-    def sample_to_lowres_estimated_image(self, samples):
+    @staticmethod
+    def sample_to_lowres_estimated_image(samples):
        # origingally adapted from code by @erucipe and @keturn here:
        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
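Making the method static lets `TextToImageInvocation.dispatch_progress` call it without a generator instance. A hedged sketch of the underlying idea (project the four latent channels to RGB with a small fixed matrix); the factors below are the commonly cited approximation, not necessarily InvokeAI's exact values:

```python
import torch
from PIL import Image

LATENT_RGB_FACTORS = torch.tensor([
    [0.298, 0.207, 0.208],   # contribution of latent channel 0 to (R, G, B)
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

def lowres_preview(latents: torch.Tensor) -> Image.Image:
    """latents: (1, 4, H/8, W/8) sampler output; returns a rough RGB preview."""
    rgb = torch.einsum("chw,cr->rhw", latents[0].cpu().float(), LATENT_RGB_FACTORS)
    rgb = ((rgb + 1.0) / 2.0).clamp(0.0, 1.0)          # crude normalization to [0, 1]
    array = (rgb.permute(1, 2, 0).numpy() * 255).astype("uint8")
    return Image.fromarray(array)

preview = lowres_preview(torch.randn(1, 4, 64, 64))
print(preview.size)  # (64, 64) -> the caller multiplies by 8 to report 512x512
```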
@@ -159,6 +159,7 @@ class Inpaint(Img2Img):
        seam_size: int,
        seam_blur: int,
        prompt,
+        seed,
        sampler,
        steps,
        cfg_scale,

@@ -192,7 +193,7 @@ class Inpaint(Img2Img):

        seam_noise = self.get_noise(im.width, im.height)

-        result = make_image(seam_noise)
+        result = make_image(seam_noise, seed)

        return result

@@ -342,6 +343,7 @@ class Inpaint(Img2Img):
            seam_size,
            seam_blur,
            prompt,
+            seed,
            sampler,
            seam_steps,
            cfg_scale,
@@ -1086,9 +1086,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
        dlogging.set_verbosity_error()

    checkpoint = (
-        load_file(checkpoint_path)
-        if Path(checkpoint_path).suffix == ".safetensors"
-        else torch.load(checkpoint_path)
+        torch.load(checkpoint_path)
+        if Path(checkpoint_path).suffix == ".ckpt"
+        else load_file(checkpoint_path)

    )
    cache_dir = global_cache_dir("hub")
    pipeline_class = (

@@ -730,9 +730,9 @@ v Apply picklescanner to the indicated checkpoint and issue a warning

    # another round of heuristics to guess the correct config file.
    checkpoint = (
-        safetensors.torch.load_file(model_path)
-        if model_path.suffix == ".safetensors"
-        else torch.load(model_path)
+        torch.load(model_path)
+        if model_path.suffix == ".ckpt"
+        else safetensors.torch.load_file(model_path)
    )

    # additional probing needed if no config file provided
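Both hunks flip the same condition so that safetensors is the default and pickle-based `.ckpt` files are the special case. A minimal standalone sketch of the resulting rule; the example path in the comment is illustrative only:

```python
from pathlib import Path

import safetensors.torch
import torch

def load_checkpoint_state_dict(model_path: Path) -> dict:
    # Mirrors the flipped condition above: anything that is not explicitly a
    # .ckpt pickle is loaded with the safetensors loader.
    if model_path.suffix == ".ckpt":
        return torch.load(model_path, map_location="cpu")
    return safetensors.torch.load_file(model_path)

# e.g. load_checkpoint_state_dict(Path("models/ldm/stable-diffusion-v1/model.safetensors"))
```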
@@ -3,6 +3,9 @@ import math
import multiprocessing as mp
import os
import re
+import io
+import base64

from collections import abc
from inspect import isfunction
from pathlib import Path

@@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
    result = download_with_resume(url, dest, access_token=None)
    return result is not None


+def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
+    """
+    Converts an image into a base64 image dataURL.
+    """
+    buffered = io.BytesIO()
+    image.save(buffered, format=image_format)
+    mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
+    image_base64 = f"data:{mime_type};base64," + base64.b64encode(
+        buffered.getvalue()
+    ).decode("UTF-8")
+    return image_base64
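A quick round-trip check of the data-URL format produced by `image_to_dataURL`; the encode step below mirrors the helper added above, and the decode step is what an `<img src=...>` effectively does:

```python
import base64
import io

from PIL import Image

img = Image.new("RGB", (64, 64), color=(128, 0, 255))
buf = io.BytesIO()
img.save(buf, format="JPEG")
data_url = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("UTF-8")

header, payload = data_url.split(",", 1)
print(header)                                              # data:image/jpeg;base64
restored = Image.open(io.BytesIO(base64.b64decode(payload)))
print(restored.size)                                       # (64, 64)
```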
invokeai/frontend/web/dist/assets/App-843b023b.js (188 lines, vendored, new file): file diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/assets/App-982926da.js (188 lines, vendored): file diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
-import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-2ad84bef.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-e63a2dc4.js";var Or=`
+import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
:root {
  --chakra-vh: 100vh;
}

(Two further vendored dist assets: file diff suppressed because one or more lines are too long.)
invokeai/frontend/web/dist/index.html (2 lines, vendored)
@@ -12,7 +12,7 @@
        margin: 0;
      }
    </style>
-    <script type="module" crossorigin src="./assets/index-2ad84bef.js"></script>
+    <script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
    <link rel="stylesheet" href="./assets/index-5483945c.css">
  </head>
invokeai/frontend/web/dist/locales/en.json (3 lines, vendored)
@@ -64,6 +64,8 @@
  "trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
  "upload": "Upload",
  "close": "Close",
+  "cancel": "Cancel",
+  "accept": "Accept",
  "load": "Load",
  "back": "Back",
  "statusConnected": "Connected",

@@ -333,6 +335,7 @@
  "addNewModel": "Add New Model",
  "addCheckpointModel": "Add Checkpoint / Safetensor Model",
+  "addDiffuserModel": "Add Diffusers",
  "scanForModels": "Scan For Models",
  "addManually": "Add Manually",
  "manual": "Manual",
  "name": "Name",
36
invokeai/frontend/web/index.d.ts
vendored
36
invokeai/frontend/web/index.d.ts
vendored
@ -1,3 +1,7 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
|
||||
export {};
|
||||
|
||||
declare module 'redux-socket.io-middleware';
|
||||
@ -40,5 +44,35 @@ declare global {
|
||||
/* eslint-enable @typescript-eslint/no-explicit-any */
|
||||
}
|
||||
|
||||
declare function Invoke(): React.JSX;
|
||||
declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
|
||||
public constructor(props: ThemeChangerProps);
|
||||
}
|
||||
|
||||
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
|
||||
public constructor(props: InvokeAILogoComponentProps);
|
||||
}
|
||||
|
||||
declare class IAIPopover extends React.Component<IAIPopoverProps> {
|
||||
public constructor(props: IAIPopoverProps);
|
||||
}
|
||||
|
||||
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
|
||||
public constructor(props: IAIIconButtonProps);
|
||||
}
|
||||
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
}
|
||||
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
InvokeAiLogoComponent,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
||||
export = Invoke;
|
||||
|
@ -6,7 +6,6 @@
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"build:package": "vite build --mode=package",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
|
@ -64,6 +64,8 @@
|
||||
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"cancel": "Cancel",
|
||||
"accept": "Accept",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
@ -333,6 +335,7 @@
|
||||
"addNewModel": "Add New Model",
|
||||
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||
"addDiffuserModel": "Add Diffusers",
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"name": "Name",
|
||||
|
@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { useEffect } from 'react';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
const App = () => {
|
||||
const App = (props: PropsWithChildren) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
@ -40,7 +40,7 @@ const App = () => {
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
<SiteHeader />
|
||||
{props.children || <SiteHeader />}
|
||||
<Flex gap={4} w="full" h="full">
|
||||
<InvokeTabs />
|
||||
<ImageGalleryPanel />
|
||||
|
@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
|
||||
];
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid image heights
|
||||
export const HEIGHTS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIButton from './IAIButton';
|
||||
|
||||
type Props = {
|
||||
@ -22,10 +23,12 @@ type Props = {
|
||||
};
|
||||
|
||||
const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
acceptButtonText = 'Accept',
|
||||
acceptButtonText = t('common.accept'),
|
||||
acceptCallback,
|
||||
cancelButtonText = 'Cancel',
|
||||
cancelButtonText = t('common.cancel'),
|
||||
cancelCallback,
|
||||
children,
|
||||
title,
|
||||
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
|
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
@ -0,0 +1,8 @@
|
||||
import { chakra } from '@chakra-ui/react';
|
||||
|
||||
/**
|
||||
* Chakra-enabled <form />
|
||||
*/
|
||||
const IAIForm = chakra.form;
|
||||
|
||||
export default IAIForm;
|
@ -0,0 +1,23 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
|
||||
export function IAIFormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -8,7 +8,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
type IAIPopoverProps = PopoverProps & {
|
||||
export type IAIPopoverProps = PopoverProps & {
|
||||
triggerComponent: ReactNode;
|
||||
triggerContainerProps?: BoxProps;
|
||||
children: ReactNode;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { lazy } from 'react';
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
@ -21,14 +21,14 @@ import './i18n';
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
export default function Component() {
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App />
|
||||
<App>{props.children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
16
invokeai/frontend/web/src/exports.tsx
Normal file
16
invokeai/frontend/web/src/exports.tsx
Normal file
@ -0,0 +1,16 @@
|
||||
import Component from './component';
|
||||
|
||||
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
|
||||
import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
InvokeAiLogoComponent,
|
||||
ThemeChanger,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<ButtonGroup>
|
||||
<IAIIconButton
|
||||
|
@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
isLazy={false}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.canvasSettings')}
|
||||
|
@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
|
||||
onClick={handleSelectColorPickerTool}
|
||||
/>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('unifiedCanvas.brushOptions')}
|
||||
|
@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
>
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={`${t('parameters.sendTo')}...`}
|
||||
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaGrinStars />}
|
||||
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</IAIPopover>
|
||||
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaExpandArrowsAlt />}
|
||||
|
@ -0,0 +1,24 @@
|
||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
|
||||
type CurrentImageFallbackProps = SpinnerProps;
|
||||
|
||||
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
||||
const { size = 'xl', ...rest } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<Spinner size={size} {...rest} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageFallback;
|
@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
|
||||
src={imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
|
@ -55,7 +55,6 @@ export default function LanguagePicker() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('common.languagePickerLabel')}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||
import type { RootState } from 'app/store';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
|
||||
|
||||
return (
|
||||
<VStack gap={2} alignItems="flex-start">
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Flex columnGap={4}>
|
||||
<IAICheckbox
|
||||
isChecked={!addManually}
|
||||
label={t('modelManager.scanForModels')}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<SearchModels />
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
|
||||
{addManually && (
|
||||
{addManually ? (
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
|
||||
<VStack rowGap={2}>
|
||||
<Text fontSize={20} fontWeight="bold" alignSelf="start">
|
||||
{t('modelManager.manual')}
|
||||
</Text>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>
|
||||
{errors.description}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
</IAIFormItemWrapper>

{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="2xl"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>

{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="2xl"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>

<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
width="90%"
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>

{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>

{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
width="90%"
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>

{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>

<IAIButton
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);

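The width and height fields above rely on Formik's render-prop `Field` to drive a custom number input through `form.setFieldValue`, since the input does not emit a plain DOM change event. A minimal self-contained sketch of that pattern, using Chakra's stock `NumberInput` in place of the project's `IAINumberInput` and placeholder bounds instead of `MIN_MODEL_SIZE`/`MAX_MODEL_SIZE`:

```tsx
// Sketch only: assumes formik and @chakra-ui/react; the 64..2048 bounds and the
// ModelSizeValues shape are illustrative, not the project's real constants/types.
import { NumberInput, NumberInputField } from '@chakra-ui/react';
import { Field, FieldInputProps, FormikProps } from 'formik';

interface ModelSizeValues {
  width: number;
}

export function WidthField() {
  return (
    <Field id="width" name="width">
      {({
        field,
        form,
      }: {
        field: FieldInputProps<number>;
        form: FormikProps<ModelSizeValues>;
      }) => (
        <NumberInput
          min={64}
          max={2048}
          step={64}
          value={field.value}
          // Write the parsed number back into Formik state under the field's name.
          onChange={(_, valueAsNumber) => form.setFieldValue(field.name, valueAsNumber)}
        >
          <NumberInputField id="width" name="width" />
        </NumberInput>
      )}
    </Field>
  );
}
```

The component has to be rendered inside a `<Formik>` provider whose values include a numeric `width`, exactly as in the checkpoint form above.
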
@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import { BiArrowBack } from 'react-icons/bi';

import type { RootState } from 'app/store';
import type { ReactElement } from 'react';

function FormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';

export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {

return (
<Flex>
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>

<FormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>

<FormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>

<FormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>

<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

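The hunk above deletes the file-local `FormItemWrapper` in favour of a shared `IAIFormItemWrapper` import. The shared component's source is not part of this diff; a plausible shape, assuming it simply mirrors the local wrapper being removed:

```tsx
// Hypothetical reconstruction based on the deleted FormItemWrapper above;
// the real common/components/IAIForms/IAIFormItemWrapper may differ.
import { Flex } from '@chakra-ui/react';
import type { ReactElement } from 'react';

export function IAIFormItemWrapper({
  children,
}: {
  children: ReactElement | ReactElement[];
}) {
  return (
    <Flex
      sx={{
        flexDirection: 'column',
        padding: 4,
        rowGap: 4,
        borderRadius: 'base',
        width: 'full',
        bg: 'base.900', // assumes the app theme defines a base.900 color token
      }}
    >
      {children}
    </Flex>
  );
}
```

Centralising the wrapper keeps the checkpoint and diffusers forms visually consistent without each file re-declaring the same Flex styling.
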
@ -14,7 +14,7 @@ import {

import IAIButton from 'common/components/IAIButton';

import { FaPlus } from 'react-icons/fa';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';

import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next';
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';

function AddModelBox({
text,
@ -83,8 +84,22 @@ export default function AddModel() {
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (

@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';

const selector = createSelector(
[systemSelector],
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

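This commit repeatedly swaps a raw `<form onSubmit={handleSubmit}>` for `<IAIForm onSubmit={handleSubmit}>`. The wrapper itself is not shown in the diff; a minimal sketch of what such a component could look like, assuming it only forwards props to a Chakra-styled `form` element:

```tsx
// Hedged sketch: the real common/components/IAIForm may accept different props.
import { chakra, HTMLChakraProps } from '@chakra-ui/react';

export default function IAIForm(props: HTMLChakraProps<'form'>) {
  // chakra.form renders a native <form> that also accepts Chakra style props,
  // which is what lets callers pass e.g. width="100%" alongside onSubmit.
  return <chakra.form {...props} />;
}
```
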
@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
import { isEqual, pickBy } from 'lodash';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';

const selector = createSelector(
[systemSelector],
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

@ -12,14 +12,13 @@
RadioGroup,
Spacer,
Text,
VStack,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next';

import { FaPlus, FaSearch } from 'react-icons/fa';
import { FaSearch, FaTrash } from 'react-icons/fa';

import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';

const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system;
@ -71,34 +70,32 @@ function SearchModelEntry({
};

return (
<VStack>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4}>
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4} alignItems="center" justifyContent="space-between">
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
</VStack>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
</Flex>
);
}

@ -215,10 +212,10 @@ export default function SearchModels() {
}

return (
<>
<Flex flexDirection="column" rowGap={4}>
{newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels}
</>
</Flex>
);
};

@ -245,26 +242,26 @@ export default function SearchModels() {
<Text
sx={{
fontWeight: 500,
fontSize: 'sm',
}}
variant="subtext"
>
{t('modelManager.checkpointFolder')}
</Text>
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex>
<Spacer />
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<BiReset />}
icon={<FaSearch />}
fontSize={18}
disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))}
/>
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler}
/>
</Flex>
@ -276,9 +273,9 @@ export default function SearchModels() {
}}
>
{({ handleSubmit }) => (
<form onSubmit={handleSubmit}>
<HStack columnGap={2} alignItems="flex-end" width="100%">
<FormControl isRequired width="lg">
<IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end">
<FormControl flexGrow={1}>
<Field
as={IAIInput}
id="checkpointFolder"
@ -294,12 +291,12 @@ export default function SearchModels() {
tooltip={t('modelManager.findModels')}
type="submit"
disabled={isProcessing}
paddingX={10}
px={8}
>
{t('modelManager.findModels')}
</IAIButton>
</HStack>
</form>
</IAIForm>
)}
</Formik>
)}
@ -410,7 +407,6 @@ export default function SearchModels() {
maxHeight={72}
overflowY="scroll"
borderRadius="sm"
paddingInlineEnd={4}
gap={2}
>
{foundModels.length > 0 ? (

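The search-folder form above now stretches the input with `flexGrow={1}` inside an `HStack` while the submit button keeps its natural width. An isolated sketch of that layout with stock Chakra components (the handler, ids, and labels are placeholders, not the project's):

```tsx
// Layout sketch only: IAIForm/IAIInput/IAIButton are replaced by Chakra equivalents.
import { Button, FormControl, HStack, Input } from '@chakra-ui/react';
import { Field, Formik } from 'formik';

export function FolderSearchForm({ onSearch }: { onSearch: (folder: string) => void }) {
  return (
    <Formik
      initialValues={{ checkpointFolder: '' }}
      onSubmit={(values) => onSearch(values.checkpointFolder)}
    >
      {({ handleSubmit }) => (
        <form onSubmit={handleSubmit} style={{ width: '100%' }}>
          <HStack columnGap={2} alignItems="flex-end">
            {/* The input grows to fill the row; the button stays content-sized. */}
            <FormControl flexGrow={1}>
              <Field as={Input} id="checkpointFolder" name="checkpointFolder" />
            </FormControl>
            <Button type="submit" px={8}>
              Find models
            </Button>
          </HStack>
        </form>
      )}
    </Formik>
  );
}
```
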
@ -50,7 +50,6 @@ export default function ThemeChanger() {

return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.themeLabel')}

@ -166,20 +166,8 @@ export default function InvokeTabs() {
[]
);

/**
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
*
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
* the watcher for the hidden drawers, closing the open drawer.
*
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
* tab's drawer.
*/
return (
<Tabs
isLazy
defaultIndex={activeTab}
index={activeTab}
onChange={(index: number) => {

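The InvokeTabs hunk drops the explanatory comment but keeps `isLazy` together with a controlled `index`/`onChange` pair, so only the active tab's panel stays mounted. A small standalone sketch of that combination (tab labels and local state here are illustrative):

```tsx
// Controlled, lazily mounted Chakra tabs: inactive panels are unmounted.
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@chakra-ui/react';
import { useState } from 'react';

export function LazyTabsExample() {
  const [activeTab, setActiveTab] = useState(0);

  return (
    <Tabs isLazy index={activeTab} onChange={(index: number) => setActiveTab(index)}>
      <TabList>
        <Tab>Text to Image</Tab>
        <Tab>Unified Canvas</Tab>
      </TabList>
      <TabPanels>
        {/* With isLazy, each panel mounts only while its tab is selected. */}
        <TabPanel>txt2img workspace</TabPanel>
        <TabPanel>canvas workspace</TabPanel>
      </TabPanels>
    </Tabs>
  );
}
```
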
@ -93,12 +93,9 @@ const ResizableDrawer = ({
useOutsideClick({
ref: outsideClickRef,
handler: () => {
if (isPinned) {
return;
}

onClose();
},
enabled: isOpen && !isPinned,
});

const handleEnables = useMemo(

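The ResizableDrawer hunk replaces the early-return guard inside the handler with Chakra's `enabled` option, so the outside-click listener is only active while the drawer is open and unpinned. An isolated sketch of the same pattern (the component and state names are illustrative):

```tsx
// Gate useOutsideClick with `enabled` instead of bailing out inside the handler.
import { useOutsideClick } from '@chakra-ui/react';
import { useRef, useState } from 'react';

export function DismissablePanel({ isPinned }: { isPinned: boolean }) {
  const panelRef = useRef<HTMLDivElement>(null);
  const [isOpen, setIsOpen] = useState(true);

  useOutsideClick({
    ref: panelRef,
    handler: () => setIsOpen(false),
    // While disabled, clicks are ignored entirely, so hidden or pinned panels
    // no longer swallow clicks meant for whichever panel is actually open.
    enabled: isOpen && !isPinned,
  });

  return isOpen ? <div ref={panelRef}>panel content</div> : null;
}
```
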
@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {

return (
<IAIPopover
trigger="hover"
triggerComponent={
<Box
sx={{

@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {

return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

File diff suppressed because one or more lines are too long
@ -1,4 +1,3 @@
import path from 'path';
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { defineConfig, PluginOption } from 'vite';
@ -58,26 +57,6 @@ export default defineConfig(({ mode }) => {
// sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit
},
};
} else if (mode === 'package') {
return {
...common,
build: {
...common.build,
lib: {
entry: path.resolve(__dirname, 'src/component.tsx'),
name: 'InvokeAI UI',
fileName: (format) => `invoke-ai-ui.${format}.js`,
},
rollupOptions: {
external: ['react', 'react-dom'],
output: {
globals: {
react: 'React',
},
},
},
},
};
} else {
return {
...common,

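The vite.config hunk removes the `package` branch that built the UI as a library. For reference, mode-conditional configs of this kind key off the `--mode` flag passed to `vite build`; a trimmed sketch of the shape being removed (the entry path, names, and build options here are placeholders, not the project's real values):

```ts
// vite.config.ts sketch: pick a build shape from `vite build --mode <name>`.
import path from 'path';
import { defineConfig } from 'vite';

export default defineConfig(({ mode }) => {
  const common = { build: { outDir: 'dist', chunkSizeWarningLimit: 1500 } };

  if (mode === 'package') {
    // Library build, analogous to the branch deleted in the hunk above.
    return {
      ...common,
      build: {
        ...common.build,
        lib: {
          entry: path.resolve(__dirname, 'src/component.tsx'),
          name: 'ExampleUI',
          fileName: (format: string) => `example-ui.${format}.js`,
        },
        rollupOptions: {
          external: ['react', 'react-dom'],
          output: { globals: { react: 'React' } },
        },
      },
    };
  }

  return common; // default app build
});
```
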
@ -38,16 +38,16 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==0.1.10",
"compel==1.0.1",
"datasets",
"diffusers[torch]~=0.14",
"dnspython==2.2.1",
"einops",
"eventlet",
"facexlib",
"fastapi==0.85.0",
"fastapi-events==0.6.0",
"fastapi-socketio==0.0.9",
"fastapi==0.94.1",
"fastapi-events==0.8.0",
"fastapi-socketio==0.0.10",
"flask==2.1.3",
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
@ -75,7 +75,7 @@ dependencies = [
"torchvision>=0.14.1",
"torchmetrics",
"transformers~=4.26",
"uvicorn[standard]==0.20.0",
"uvicorn[standard]==0.21.1",
"windows-curses; sys_platform=='win32'",
]

@ -105,17 +105,20 @@

// Start building nodes
var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
id++;
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++

nodes = {};
nodes[initialNode.id] = initialNode;
nodes[i2iNode.id] = i2iNode;
nodes[upscaleNode.id] = upscaleNode;
links = [
[{ "node_id": initialNode.id, field: "image" },
{ "node_id": upscaleNode.id, field: "image" }]
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
];
// expandSize = 128;
// for (var i = 0; i < 6; ++i) {

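The links array in the hunk above switches from bare two-element connection pairs to objects with explicit `source` and `destination` ends. A typed sketch of how a client might model the new payload (the type names are assumptions for illustration, not part of the API):

```ts
// Assumed client-side model of the graph payload built by the script above.
interface EdgeConnection {
  node_id: string;
  field: string;
}

interface Edge {
  source: EdgeConnection;
  destination: EdgeConnection;
}

interface GraphPayload {
  nodes: Record<string, { id: string; type: string } & Record<string, unknown>>;
  links: Edge[];
}

// txt2img -> img2img -> show_image, mirroring the example graph in the diff.
const payload: GraphPayload = {
  nodes: {
    '1': { id: '1', type: 'txt2img', prompt: 'Banana sushi' },
    '2': { id: '2', type: 'img2img', prompt: 'Banana sushi' },
    '3': { id: '3', type: 'show_image' },
  },
  links: [
    { source: { node_id: '1', field: 'image' }, destination: { node_id: '2', field: 'image' } },
    { source: { node_id: '2', field: 'image' }, destination: { node_id: '3', field: 'image' } },
  ],
};
```
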
@ -1,15 +1,18 @@
from invokeai.app.invocations.image import *

from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
import pytest


# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field)
)

# Tests
def test_connections_are_compatible():
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
assert g.get_node("3").prompt == "Banana sushi"

assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges

def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
assert g2.edges[0].source.node_id == '1'
assert g2.edges[0].source.field == 'image'
assert g2.edges[0].destination.node_id == '2'
assert g2.edges[0].destination.field == 'image'

def test_graph_can_generate_schema():
# Not throwing on this line is sufficient

@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):


from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import EdgeConnection
from invokeai.app.services.graph import Edge, EdgeConnection

def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field))


class TestEvent: