Merge branch 'main' into feat/use-custom-vaes

This commit is contained in:
Lincoln Stein 2023-03-23 10:32:56 -04:00
commit a958ae5e29
55 changed files with 854 additions and 724 deletions

View File

@ -17,7 +17,7 @@ notebooks.
You will need a GPU to perform training in a reasonable length of You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers` time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS) to accelerate the library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs. needed in order to store intermediate models, checkpoints and logs.

View File

@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
brew install opencv brew install opencv
``` ```
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built. The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
## Linux ## Linux
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
5. Confirm that pypatchmatch is installed. At the command-line prompt enter 5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type `python`, and then at the `>>>` line type
`from patchmatch import patch_match`: It should look like the follwing: `from patchmatch import patch_match`: It should look like the following:
```py ```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38) Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux) [**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
If you see no errors, then you're ready to go! If you see no errors you're ready to go!

View File

@ -10,6 +10,7 @@ from pydantic.fields import Field
from ...invocations import * from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import ( from ...services.graph import (
Edge,
EdgeConnection, EdgeConnection,
Graph, Graph,
GraphExecutionState, GraphExecutionState,
@ -92,7 +93,7 @@ async def get_session(
async def add_node( async def add_node(
session_id: str = Path(description="The id of the session"), session_id: str = Path(description="The id of the session"),
node: Annotated[ node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"), ] = Body(description="The node to add"),
) -> str: ) -> str:
"""Adds a node to the graph""" """Adds a node to the graph"""
@ -125,7 +126,7 @@ async def update_node(
session_id: str = Path(description="The id of the session"), session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"), node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[ node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"), ] = Body(description="The new node"),
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges""" """Updates a node in the graph and removes all linked edges"""
@ -186,7 +187,7 @@ async def delete_node(
) )
async def add_edge( async def add_edge(
session_id: str = Path(description="The id of the session"), session_id: str = Path(description="The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"), edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState: ) -> GraphExecutionState:
"""Adds an edge to the graph""" """Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -228,9 +229,9 @@ async def delete_edge(
return Response(status_code=404) return Response(status_code=404)
try: try:
edge = ( edge = Edge(
EdgeConnection(node_id=from_node_id, field=from_field), source=EdgeConnection(node_id=from_node_id, field=from_field),
EdgeConnection(node_id=to_node_id, field=to_field), destination=EdgeConnection(node_id=to_node_id, field=to_field)
) )
session.delete_edge(edge) session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set( ApiDependencies.invoker.services.graph_execution_manager.set(

View File

@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.graph import EdgeConnection, GraphExecutionState from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
def generate_matching_edges( def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]: ) -> list[Edge]:
"""Generates all possible edges between two invocations""" """Generates all possible edges between two invocations"""
atype = type(a) atype = type(a)
btype = type(b) btype = type(b)
@ -94,9 +94,9 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields) matching_fields = matching_fields.difference(invalid_fields)
edges = [ edges = [
( Edge(
EdgeConnection(node_id=a.id, field=field), source=EdgeConnection(node_id=a.id, field=field),
EdgeConnection(node_id=b.id, field=field), destination=EdgeConnection(node_id=b.id, field=field)
) )
for field in matching_fields for field in matching_fields
] ]
@ -111,16 +111,15 @@ class SessionError(Exception):
def invoke_all(context: CliContext): def invoke_all(context: CliContext):
"""Runs all invocations in the specified session""" """Runs all invocations in the specified session"""
context.invoker.invoke(context.session, invoke_all=True) context.invoker.invoke(context.session, invoke_all=True)
while not context.session.is_complete(): while not context.get_session().is_complete():
# Wait some time # Wait some time
session = context.get_session()
time.sleep(0.1) time.sleep(0.1)
# Print any errors # Print any errors
if context.session.has_error(): if context.session.has_error():
for n in context.session.errors: for n in context.session.errors:
print( print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}" f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
) )
raise SessionError() raise SessionError()
@ -203,7 +202,7 @@ def invoke_cli():
continue continue
# Pipe previous command output (if there was a previous command) # Pipe previous command output (if there was a previous command)
edges = [] edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id: if len(history) > 0 or current_id != start_id:
from_id = ( from_id = (
history[0] if current_id == start_id else str(current_id - 1) history[0] if current_id == start_id else str(current_id - 1)
@ -225,19 +224,19 @@ def invoke_cli():
matching_edges = generate_matching_edges( matching_edges = generate_matching_edges(
link_node, command.command link_node, command.command
) )
matching_destinations = [e[1] for e in matching_edges] matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations] edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges) edges.extend(matching_edges)
if "link" in args and args["link"]: if "link" in args and args["link"]:
for link in args["link"]: for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]] edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges.append( edges.append(
( Edge(
EdgeConnection(node_id=link[1], field=link[0]), source=EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection( destination=EdgeConnection(
node_id=command.command.id, field=link[2] node_id=command.command.id, field=link[2]
), )
) )
) )

View File

@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
from torch import Tensor
from PIL import Image from PIL import Image
from pydantic import Field from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms from skimage.exposure.histogram_matching import match_histograms
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers()) tuple(InvokeAIGenerator.schedulers())
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0 self, context: InvocationContext, sample: Tensor, step: int
) -> None: ) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress( context.services.events.emit_generator_progress(
context.graph_execution_state_id, context.graph_execution_state_id,
self.id, self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step, step,
float(step) / float(self.steps), self.steps,
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, sample, step) self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache

View File

@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase: class EventServiceBase:
session_event: str = "session_event" session_event: str = "session_event"
@ -23,8 +26,9 @@ class EventServiceBase:
self, self,
graph_execution_state_id: str, graph_execution_state_id: str,
invocation_id: str, invocation_id: str,
progress_image: ProgressImage | None,
step: int, step: int,
percent: float, total_steps: int,
) -> None: ) -> None:
"""Emitted when there is generation progress""" """Emitted when there is generation progress"""
self.__emit_session_event( self.__emit_session_event(
@ -32,8 +36,9 @@ class EventServiceBase:
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id, invocation_id=invocation_id,
progress_image=progress_image,
step=step, step=step,
percent=percent, total_steps=total_steps,
), ),
) )

View File

@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
return hash(f"{self.node_id}.{self.field}") return hash(f"{self.node_id}.{self.field}")
class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def get_output_field(node: BaseInvocation, field: str) -> Any: def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node) node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type()) node_outputs = get_type_hints(node_type.get_output_type())
@ -194,7 +199,7 @@ class Graph(BaseModel):
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict description="The nodes in this graph", default_factory=dict
) )
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field( edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph", description="The connections between nodes and their fields in this graph",
default_factory=list, default_factory=list,
) )
@ -251,7 +256,7 @@ class Graph(BaseModel):
except NodeNotFoundError: except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?) pass # Ignore, not doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: def add_edge(self, edge: Edge) -> None:
"""Adds an edge to a graph """Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid. :raises InvalidEdgeError: the provided edge is invalid.
@ -262,7 +267,7 @@ class Graph(BaseModel):
else: else:
raise InvalidEdgeError() raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: def delete_edge(self, edge: Edge) -> None:
"""Deletes an edge from a graph""" """Deletes an edge from a graph"""
try: try:
@ -280,7 +285,7 @@ class Graph(BaseModel):
# Validate all edges reference nodes in the graph # Validate all edges reference nodes in the graph
node_ids = set( node_ids = set(
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges] [e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
) )
if not all((self.has_node(node_id) for node_id in node_ids)): if not all((self.has_node(node_id) for node_id in node_ids)):
return False return False
@ -294,10 +299,10 @@ class Graph(BaseModel):
if not all( if not all(
( (
are_connections_compatible( are_connections_compatible(
self.get_node(e[0].node_id), self.get_node(e.source.node_id),
e[0].field, e.source.field,
self.get_node(e[1].node_id), self.get_node(e.destination.node_id),
e[1].field, e.destination.field,
) )
for e in self.edges for e in self.edges
) )
@ -328,58 +333,58 @@ class Graph(BaseModel):
return True return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: def _is_edge_valid(self, edge: Edge) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph""" """Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try: try:
from_node = self.get_node(edge[0].node_id) from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge[1].node_id) to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError: except NodeNotFoundError:
return False return False
# Validate that an edge to this node+field doesn't already exist # Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field) input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False return False
# Validate that no cycles would be created # Validate that no cycles would be created
g = self.nx_graph_flat() g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id) g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g): if not nx.is_directed_acyclic_graph(g):
return False return False
# Validate that the field types are compatible # Validate that the field types are compatible
if not are_connections_compatible( if not are_connections_compatible(
from_node, edge[0].field, to_node, edge[1].field from_node, edge.source.field, to_node, edge.destination.field
): ):
return False return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set) # Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection": if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge[1].node_id, new_input=edge[0] edge.destination.node_id, new_input=edge.source
): ):
return False return False
# Validate if iterator input type matches output type (if this edge results in both being set) # Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == "item": if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge[0].node_id, new_output=edge[1] edge.source.node_id, new_output=edge.destination
): ):
return False return False
# Validate if collector input type matches output type (if this edge results in both being set) # Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == "item": if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge[1].node_id, new_input=edge[0] edge.destination.node_id, new_input=edge.source
): ):
return False return False
# Validate if collector output type matches input type (if this edge results in both being set) # Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection": if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge[0].node_id, new_output=edge[1] edge.source.node_id, new_output=edge.destination
): ):
return False return False
@ -438,15 +443,15 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path # Remove the graph prefix from the node path
new_graph_node_path = ( new_graph_node_path = (
new_node.id new_node.id
if "." not in edge[1].node_id if "." not in edge.destination.node_id
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}' else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
) )
graph.add_edge( graph.add_edge(
( Edge(
edge[0], source=edge.source,
EdgeConnection( destination=EdgeConnection(
node_id=new_graph_node_path, field=edge[1].field node_id=new_graph_node_path, field=edge.destination.field
), )
) )
) )
@ -454,51 +459,51 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path # Remove the graph prefix from the node path
new_graph_node_path = ( new_graph_node_path = (
new_node.id new_node.id
if "." not in edge[0].node_id if "." not in edge.source.node_id
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}' else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
) )
graph.add_edge( graph.add_edge(
( Edge(
EdgeConnection( source=EdgeConnection(
node_id=new_graph_node_path, field=edge[0].field node_id=new_graph_node_path, field=edge.source.field
), ),
edge[1], destination=edge.destination
) )
) )
def _get_input_edges( def _get_input_edges(
self, node_path: str, field: Optional[str] = None self, node_path: str, field: Optional[str] = None
) -> list[tuple[EdgeConnection, EdgeConnection]]: ) -> list[Edge]:
"""Gets all input edges for a node""" """Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path) edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field # Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field) filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
# Create full node paths for each edge # Create full node paths for each edge
return [ return [
( Edge(
EdgeConnection( source=EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix), node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e[0].field, field=e.source.field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
), ),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
) )
for _, prefix, e in filtered_edges for _, prefix, e in filtered_edges
] ]
def _get_input_edges_and_graphs( def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]: ) -> list[tuple["Graph", str, Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path""" """Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list() edges = list()
# Return any input edges that appear in this graph # Return any input edges that appear in this graph
edges.extend( edges.extend(
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path] [(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
) )
node_id = ( node_id = (
@ -522,37 +527,37 @@ class Graph(BaseModel):
def _get_output_edges( def _get_output_edges(
self, node_path: str, field: str self, node_path: str, field: str
) -> list[tuple[EdgeConnection, EdgeConnection]]: ) -> list[Edge]:
"""Gets all output edges for a node""" """Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path) edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field # Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field) filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge # Create full node paths for each edge
return [ return [
( Edge(
EdgeConnection( source=EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix), node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e[0].field, field=e.source.field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
), ),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
) )
for _, prefix, e in filtered_edges for _, prefix, e in filtered_edges
] ]
def _get_output_edges_and_graphs( def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]: ) -> list[tuple["Graph", str, Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path""" """Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list() edges = list()
# Return any input edges that appear in this graph # Return any input edges that appear in this graph
edges.extend( edges.extend(
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path] [(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
) )
node_id = ( node_id = (
@ -580,8 +585,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None, new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None,
) -> bool: ) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")]) inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")]) outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
if new_input is not None: if new_input is not None:
inputs.append(new_input) inputs.append(new_input)
@ -622,8 +627,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None, new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None,
) -> bool: ) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")]) inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")]) outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
if new_input is not None: if new_input is not None:
inputs.append(new_input) inputs.append(new_input)
@ -684,7 +689,7 @@ class Graph(BaseModel):
# TODO: Cache this? # TODO: Cache this?
g = nx.DiGraph() g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()]) g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges])) g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g return g
def nx_graph_flat( def nx_graph_flat(
@ -711,7 +716,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded # TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges]) unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from( g.add_edges_from(
[ [
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) (self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
def next(self) -> BaseInvocation | None: def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
input_collection_prepared_node_id = next( input_collection_prepared_node_id = next(
n[1] n[1]
for n in iteration_node_map for n in iteration_node_map
if n[0] == input_collection_edge[0].node_id if n[0] == input_collection_edge.source.node_id
) )
input_collection_prepared_node_output = self.results[ input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id input_collection_prepared_node_id
] ]
input_collection = getattr( input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge[0].field input_collection_prepared_node_output, input_collection_edge.source.field
) )
self_iteration_count = len(input_collection) self_iteration_count = len(input_collection)
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
new_edges = list() new_edges = list()
for edge in input_edges: for edge in input_edges:
for input_node_id in ( for input_node_id in (
n[1] for n in iteration_node_map if n[0] == edge[0].node_id n[1] for n in iteration_node_map if n[0] == edge.source.node_id
): ):
new_edge = ( new_edge = Edge(
EdgeConnection(node_id=input_node_id, field=edge[0].field), source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
EdgeConnection(node_id="", field=edge[1].field), destination=EdgeConnection(node_id="", field=edge.destination.field),
) )
new_edges.append(new_edge) new_edges.append(new_edge)
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
# Add new edges to execution graph # Add new edges to execution graph
for edge in new_edges: for edge in new_edges:
new_edge = ( new_edge = Edge(
edge[0], source=edge.source,
EdgeConnection(node_id=new_node.id, field=edge[1].field), destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
) )
self.execution_graph.add_edge(new_edge) self.execution_graph.add_edge(new_edge)
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
return self.execution_graph.nodes[next_node] return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation): def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id] input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
if isinstance(node, CollectInvocation): if isinstance(node, CollectInvocation):
output_collection = [ output_collection = [
getattr(self.results[edge[0].node_id], edge[0].field) getattr(self.results[edge.source.node_id], edge.source.field)
for edge in input_edges for edge in input_edges
if edge[1].field == "item" if edge.destination.field == "item"
] ]
setattr(node, "collection", output_collection) setattr(node, "collection", output_collection)
else: else:
for edge in input_edges: for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field) output_value = getattr(self.results[edge.source.node_id], edge.source.field)
setattr(node, edge[1].field, output_value) setattr(node, edge.destination.field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge): if not self._is_edge_valid(edge):
return False return False
# Invalid if destination has already been prepared or executed # Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping: if edge.destination.node_id in self.source_prepared_mapping:
return False return False
# Otherwise, the edge is valid # Otherwise, the edge is valid
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
) )
self.graph.delete_node(node_path) self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge[1].node_id): if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError( raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to" f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
) )
self.graph.add_edge(edge) self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: def delete_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge[1].node_id): if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError( raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted" f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
) )
self.graph.delete_edge(edge) self.graph.delete_edge(edge)

View File

@ -490,7 +490,7 @@ class Args(object):
"-z", "-z",
type=int, type=int,
default=6, default=6,
choices=range(0, 9), choices=range(0, 10),
dest="png_compression", dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.", help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
) )
@ -943,7 +943,6 @@ class Args(object):
"--png_compression", "--png_compression",
"-z", "-z",
type=int, type=int,
default=6,
choices=range(0, 10), choices=range(0, 10),
dest="png_compression", dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]", help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",

View File

@ -497,7 +497,8 @@ class Generator:
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result return matched_result
def sample_to_lowres_estimated_image(self, samples): @staticmethod
def sample_to_lowres_estimated_image(samples):
# origingally adapted from code by @erucipe and @keturn here: # origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

View File

@ -159,6 +159,7 @@ class Inpaint(Img2Img):
seam_size: int, seam_size: int,
seam_blur: int, seam_blur: int,
prompt, prompt,
seed,
sampler, sampler,
steps, steps,
cfg_scale, cfg_scale,
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height) seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise) result = make_image(seam_noise, seed)
return result return result
@ -342,6 +343,7 @@ class Inpaint(Img2Img):
seam_size, seam_size,
seam_blur, seam_blur,
prompt, prompt,
seed,
sampler, sampler,
seam_steps, seam_steps,
cfg_scale, cfg_scale,

View File

@ -1086,9 +1086,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
dlogging.set_verbosity_error() dlogging.set_verbosity_error()
checkpoint = ( checkpoint = (
load_file(checkpoint_path) torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors" if Path(checkpoint_path).suffix == ".ckpt"
else torch.load(checkpoint_path) else load_file(checkpoint_path)
) )
cache_dir = global_cache_dir("hub") cache_dir = global_cache_dir("hub")
pipeline_class = ( pipeline_class = (

View File

@ -730,9 +730,9 @@ v Apply picklescanner to the indicated checkpoint and issue a warning
# another round of heuristics to guess the correct config file. # another round of heuristics to guess the correct config file.
checkpoint = ( checkpoint = (
safetensors.torch.load_file(model_path) torch.load(model_path)
if model_path.suffix == ".safetensors" if model_path.suffix == ".ckpt"
else torch.load(model_path) else safetensors.torch.load_file(model_path)
) )
# additional probing needed if no config file provided # additional probing needed if no config file provided

View File

@ -3,6 +3,9 @@ import math
import multiprocessing as mp import multiprocessing as mp
import os import os
import re import re
import io
import base64
from collections import abc from collections import abc
from inspect import isfunction from inspect import isfunction
from pathlib import Path from pathlib import Path
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool: def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None) result = download_with_resume(url, dest, access_token=None)
return result is not None return result is not None
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-2ad84bef.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-e63a2dc4.js";var Or=` import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
:root { :root {
--chakra-vh: 100vh; --chakra-vh: 100vh;
} }

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-2ad84bef.js"></script> <script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
<link rel="stylesheet" href="./assets/index-5483945c.css"> <link rel="stylesheet" href="./assets/index-5483945c.css">
</head> </head>

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.", "trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload", "upload": "Upload",
"close": "Close", "close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load", "load": "Load",
"back": "Back", "back": "Back",
"statusConnected": "Connected", "statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model", "addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model", "addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers", "addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"name": "Name", "name": "Name",

View File

@ -1,3 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
export {}; export {};
declare module 'redux-socket.io-middleware'; declare module 'redux-socket.io-middleware';
@ -40,5 +44,35 @@ declare global {
/* eslint-enable @typescript-eslint/no-explicit-any */ /* eslint-enable @typescript-eslint/no-explicit-any */
} }
declare function Invoke(): React.JSX; declare module '@invoke-ai/invoke-ai-ui' {
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
public constructor(props: ThemeChangerProps);
}
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
public constructor(props: InvokeAILogoComponentProps);
}
declare class IAIPopover extends React.Component<IAIPopoverProps> {
public constructor(props: IAIPopoverProps);
}
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
public constructor(props: IAIIconButtonProps);
}
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
export {
ThemeChanger,
InvokeAiLogoComponent,
IAIPopover,
IAIIconButton,
SettingsModal,
};
export = Invoke; export = Invoke;

View File

@ -6,7 +6,6 @@
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky", "prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"build:package": "vite build --mode=package",
"preview": "vite preview", "preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx", "lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .", "lint:eslint": "eslint --max-warnings=0 .",

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.", "trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload", "upload": "Upload",
"close": "Close", "close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load", "load": "Load",
"back": "Back", "back": "Back",
"statusConnected": "Connected", "statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model", "addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model", "addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers", "addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually", "addManually": "Add Manually",
"manual": "Manual", "manual": "Manual",
"name": "Name", "name": "Name",

View File

@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel'; import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks'; import { useAppSelector } from './storeHooks';
import { useEffect } from 'react'; import { PropsWithChildren, useEffect } from 'react';
keepGUIAlive(); keepGUIAlive();
const App = () => { const App = (props: PropsWithChildren) => {
useToastWatcher(); useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme); const currentTheme = useAppSelector((state) => state.ui.currentTheme);
@ -40,7 +40,7 @@ const App = () => {
w={APP_WIDTH} w={APP_WIDTH}
h={APP_HEIGHT} h={APP_HEIGHT}
> >
<SiteHeader /> {props.children || <SiteHeader />}
<Flex gap={4} w="full" h="full"> <Flex gap={4} w="full" h="full">
<InvokeTabs /> <InvokeTabs />
<ImageGalleryPanel /> <ImageGalleryPanel />

View File

@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
]; ];
// Valid image widths // Valid image widths
export const WIDTHS: Array<number> = [ export const WIDTHS: Array<number> = Array.from(Array(65)).map(
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, (_x, i) => i * 64
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, );
1856, 1920, 1984, 2048,
];
// Valid image heights // Valid image heights
export const HEIGHTS: Array<number> = [ export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, (_x, i) => i * 64
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, );
1856, 1920, 1984, 2048,
];
// Valid upscaling levels // Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [ export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [

View File

@ -9,6 +9,7 @@ import {
useDisclosure, useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react'; import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton'; import IAIButton from './IAIButton';
type Props = { type Props = {
@ -22,10 +23,12 @@ type Props = {
}; };
const IAIAlertDialog = forwardRef((props: Props, ref) => { const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { t } = useTranslation();
const { const {
acceptButtonText = 'Accept', acceptButtonText = t('common.accept'),
acceptCallback, acceptCallback,
cancelButtonText = 'Cancel', cancelButtonText = t('common.cancel'),
cancelCallback, cancelCallback,
children, children,
title, title,
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
isOpen={isOpen} isOpen={isOpen}
leastDestructiveRef={cancelRef} leastDestructiveRef={cancelRef}
onClose={onClose} onClose={onClose}
isCentered
> >
<AlertDialogOverlay> <AlertDialogOverlay>
<AlertDialogContent> <AlertDialogContent>

View File

@ -0,0 +1,8 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@ -0,0 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { ReactElement } from 'react';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}

View File

@ -8,7 +8,7 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { memo, ReactNode } from 'react'; import { memo, ReactNode } from 'react';
type IAIPopoverProps = PopoverProps & { export type IAIPopoverProps = PopoverProps & {
triggerComponent: ReactNode; triggerComponent: ReactNode;
triggerContainerProps?: BoxProps; triggerContainerProps?: BoxProps;
children: ReactNode; children: ReactNode;

View File

@ -1,4 +1,4 @@
import React, { lazy } from 'react'; import React, { lazy, PropsWithChildren } from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react'; import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store'; import { store } from './app/store';
@ -21,14 +21,14 @@ import './i18n';
const App = lazy(() => import('./app/App')); const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider')); const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component() { export default function Component(props: PropsWithChildren) {
return ( return (
<React.StrictMode> <React.StrictMode>
<Provider store={store}> <Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}> <PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}> <React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<App /> <App>{props.children}</App>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>
</PersistGate> </PersistGate>

View File

@ -0,0 +1,16 @@
import Component from './component';
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
export default Component;
export {
InvokeAiLogoComponent,
ThemeChanger,
IAIPopover,
IAIIconButton,
SettingsModal,
};

View File

@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
return ( return (
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<ButtonGroup> <ButtonGroup>
<IAIIconButton <IAIIconButton

View File

@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
return ( return (
<IAIPopover <IAIPopover
trigger="hover" isLazy={false}
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')} tooltip={t('unifiedCanvas.canvasSettings')}

View File

@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
onClick={handleSelectColorPickerTool} onClick={handleSelectColorPickerTool}
/> />
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
aria-label={t('unifiedCanvas.brushOptions')} aria-label={t('unifiedCanvas.brushOptions')}

View File

@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
> >
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
icon={<FaGrinStars />} icon={<FaGrinStars />}
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</IAIPopover> </IAIPopover>
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
icon={<FaExpandArrowsAlt />} icon={<FaExpandArrowsAlt />}

View File

@ -0,0 +1,24 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default CurrentImageFallback;

View File

@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants'; import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors'; import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
src={imageToDisplay.url} src={imageToDisplay.url}
width={imageToDisplay.width} width={imageToDisplay.width}
height={imageToDisplay.height} height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
sx={{ sx={{
objectFit: 'contain', objectFit: 'contain',
maxWidth: '100%', maxWidth: '100%',

View File

@ -55,7 +55,6 @@ export default function LanguagePicker() {
return ( return (
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
aria-label={t('common.languagePickerLabel')} aria-label={t('common.languagePickerLabel')}

View File

@ -1,4 +1,5 @@
import { import {
Flex,
FormControl, FormControl,
FormErrorMessage, FormErrorMessage,
FormHelperText, FormHelperText,
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/invokeai'; import type { InvokeModelConfigProps } from 'app/invokeai';
import type { RootState } from 'app/store'; import type { RootState } from 'app/store';
import IAIIconButton from 'common/components/IAIIconButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik'; import type { FieldInputProps, FormikProps } from 'formik';
import { BiArrowBack } from 'react-icons/bi'; import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
const MIN_MODEL_SIZE = 64; const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048; const MAX_MODEL_SIZE = 2048;
@ -72,38 +73,32 @@ export default function AddCheckpointModel() {
return ( return (
<VStack gap={2} alignItems="flex-start"> <VStack gap={2} alignItems="flex-start">
<IAIIconButton <Flex columnGap={4}>
aria-label={t('common.back')} <IAICheckbox
tooltip={t('common.back')} isChecked={!addManually}
onClick={() => dispatch(setAddNewModelUIOption(null))} label={t('modelManager.scanForModels')}
width="max-content" onChange={() => setAddmanually(!addManually)}
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/> />
<SearchModels />
<IAICheckbox <IAICheckbox
label={t('modelManager.addManually')} label={t('modelManager.addManually')}
isChecked={addManually} isChecked={addManually}
onChange={() => setAddmanually(!addManually)} onChange={() => setAddmanually(!addManually)}
/> />
</Flex>
{addManually && ( {addManually ? (
<Formik <Formik
initialValues={addModelFormValues} initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler} onSubmit={addModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}> <VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start"> <Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')} {t('modelManager.manual')}
</Text> </Text>
{/* Name */} {/* Name */}
<IAIFormItemWrapper>
<FormControl <FormControl
isInvalid={!!errors.name && touched.name} isInvalid={!!errors.name && touched.name}
isRequired isRequired
@ -118,7 +113,7 @@ export default function AddCheckpointModel() {
name="name" name="name"
type="text" type="text"
validate={baseValidation} validate={baseValidation}
width="2xl" width="full"
/> />
{!!errors.name && touched.name ? ( {!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage> <FormErrorMessage>{errors.name}</FormErrorMessage>
@ -129,8 +124,10 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
{/* Description */} {/* Description */}
<IAIFormItemWrapper>
<FormControl <FormControl
isInvalid={!!errors.description && touched.description} isInvalid={!!errors.description && touched.description}
isRequired isRequired
@ -144,10 +141,12 @@ export default function AddCheckpointModel() {
id="description" id="description"
name="description" name="description"
type="text" type="text"
width="2xl" width="full"
/> />
{!!errors.description && touched.description ? ( {!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage> <FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : ( ) : (
<FormHelperText margin={0}> <FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')} {t('modelManager.descriptionValidationMsg')}
@ -155,8 +154,10 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
{/* Config */} {/* Config */}
<IAIFormItemWrapper>
<FormControl <FormControl
isInvalid={!!errors.config && touched.config} isInvalid={!!errors.config && touched.config}
isRequired isRequired
@ -170,7 +171,7 @@ export default function AddCheckpointModel() {
id="config" id="config"
name="config" name="config"
type="text" type="text"
width="2xl" width="full"
/> />
{!!errors.config && touched.config ? ( {!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage> <FormErrorMessage>{errors.config}</FormErrorMessage>
@ -181,8 +182,10 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
{/* Weights */} {/* Weights */}
<IAIFormItemWrapper>
<FormControl <FormControl
isInvalid={!!errors.weights && touched.weights} isInvalid={!!errors.weights && touched.weights}
isRequired isRequired
@ -196,7 +199,7 @@ export default function AddCheckpointModel() {
id="weights" id="weights"
name="weights" name="weights"
type="text" type="text"
width="2xl" width="full"
/> />
{!!errors.weights && touched.weights ? ( {!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage> <FormErrorMessage>{errors.weights}</FormErrorMessage>
@ -207,8 +210,10 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
{/* VAE */} {/* VAE */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}> <FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm"> <FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')} {t('modelManager.vaeLocation')}
@ -219,7 +224,7 @@ export default function AddCheckpointModel() {
id="vae" id="vae"
name="vae" name="vae"
type="text" type="text"
width="2xl" width="full"
/> />
{!!errors.vae && touched.vae ? ( {!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage> <FormErrorMessage>{errors.vae}</FormErrorMessage>
@ -230,9 +235,11 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
<HStack width="100%"> <HStack width="100%">
{/* Width */} {/* Width */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}> <FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm"> <FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')} {t('modelManager.width')}
@ -252,7 +259,6 @@ export default function AddCheckpointModel() {
min={MIN_MODEL_SIZE} min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE} max={MAX_MODEL_SIZE}
step={64} step={64}
width="90%"
value={form.values.width} value={form.values.width}
onChange={(value) => onChange={(value) =>
form.setFieldValue(field.name, Number(value)) form.setFieldValue(field.name, Number(value))
@ -270,8 +276,10 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
{/* Height */} {/* Height */}
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}> <FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm"> <FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')} {t('modelManager.height')}
@ -290,7 +298,6 @@ export default function AddCheckpointModel() {
name="height" name="height"
min={MIN_MODEL_SIZE} min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE} max={MAX_MODEL_SIZE}
width="90%"
step={64} step={64}
value={form.values.height} value={form.values.height}
onChange={(value) => onChange={(value) =>
@ -309,6 +316,7 @@ export default function AddCheckpointModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</IAIFormItemWrapper>
</HStack> </HStack>
<IAIButton <IAIButton
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
{t('modelManager.addModel')} {t('modelManager.addModel')}
</IAIButton> </IAIButton>
</VStack> </VStack>
</form> </IAIForm>
)} )}
</Formik> </Formik>
) : (
<SearchModels />
)} )}
</VStack> </VStack>
); );

View File

@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
import { addNewModel } from 'app/socketio/actions'; import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik'; import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BiArrowBack } from 'react-icons/bi';
import type { RootState } from 'app/store'; import type { RootState } from 'app/store';
import type { ReactElement } from 'react'; import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
function FormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
export default function AddDiffusersModel() { export default function AddDiffusersModel() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
return ( return (
<Flex> <Flex>
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Formik <Formik
initialValues={addModelFormValues} initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler} onSubmit={addModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}> <VStack rowGap={2}>
<FormItemWrapper> <IAIFormItemWrapper>
{/* Name */} {/* Name */}
<FormControl <FormControl
isInvalid={!!errors.name && touched.name} isInvalid={!!errors.name && touched.name}
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</FormItemWrapper> </IAIFormItemWrapper>
<FormItemWrapper> <IAIFormItemWrapper>
{/* Description */} {/* Description */}
<FormControl <FormControl
isInvalid={!!errors.description && touched.description} isInvalid={!!errors.description && touched.description}
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</FormItemWrapper> </IAIFormItemWrapper>
<FormItemWrapper> <IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm"> <Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')} {t('modelManager.formMessageDiffusersModelLocation')}
</Text> </Text>
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</FormItemWrapper> </IAIFormItemWrapper>
<FormItemWrapper> <IAIFormItemWrapper>
{/* VAE Path */} {/* VAE Path */}
<Text fontWeight="bold"> <Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')} {t('modelManager.formMessageDiffusersVAELocation')}
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
)} )}
</VStack> </VStack>
</FormControl> </FormControl>
</FormItemWrapper> </IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}> <IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')} {t('modelManager.addModel')}
</IAIButton> </IAIButton>
</VStack> </VStack>
</form> </IAIForm>
)} )}
</Formik> </Formik>
</Flex> </Flex>

View File

@ -14,7 +14,7 @@ import {
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import { FaPlus } from 'react-icons/fa'; import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice'; import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel'; import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel'; import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({ function AddModelBox({
text, text,
@ -83,8 +84,22 @@ export default function AddModel() {
closeOnOverlayClick={false} closeOnOverlayClick={false}
> >
<ModalOverlay /> <ModalOverlay />
<ModalContent margin="auto" paddingInlineEnd={4}> <ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader> <ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton /> <ModalCloseButton />
<ModalBody> <ModalBody>
{addNewModelUIOption == null && ( {addNewModelUIOption == null && (

View File

@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector( const selector = createSelector(
[systemSelector], [systemSelector],
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
onSubmit={editModelFormSubmitHandler} onSubmit={editModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start"> <VStack rowGap={2} alignItems="start">
{/* Description */} {/* Description */}
<FormControl <FormControl
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</VStack> </VStack>
</form> </IAIForm>
)} )}
</Formik> </Formik>
</Flex> </Flex>

View File

@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
import { isEqual, pickBy } from 'lodash'; import { isEqual, pickBy } from 'lodash';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText'; import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage'; import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector( const selector = createSelector(
[systemSelector], [systemSelector],
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
onSubmit={editModelFormSubmitHandler} onSubmit={editModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start"> <VStack rowGap={2} alignItems="start">
{/* Description */} {/* Description */}
<FormControl <FormControl
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</VStack> </VStack>
</form> </IAIForm>
)} )}
</Formik> </Formik>
</Flex> </Flex>

View File

@ -12,14 +12,13 @@ import {
RadioGroup, RadioGroup,
Spacer, Spacer,
Text, Text,
VStack,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlus, FaSearch } from 'react-icons/fa'; import { FaSearch, FaTrash } from 'react-icons/fa';
import { addNewModel, searchForModels } from 'app/socketio/actions'; import { addNewModel, searchForModels } from 'app/socketio/actions';
import { import {
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik'; import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash'; import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react'; import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi'; import IAIForm from 'common/components/IAIForm';
const existingModelsSelector = createSelector([systemSelector], (system) => { const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system; const { model_list } = system;
@ -71,7 +70,6 @@ function SearchModelEntry({
}; };
return ( return (
<VStack>
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={2} gap={2}
@ -82,7 +80,7 @@ function SearchModelEntry({
paddingY={2} paddingY={2}
borderRadius={4} borderRadius={4}
> >
<Flex gap={4}> <Flex gap={4} alignItems="center" justifyContent="space-between">
<IAICheckbox <IAICheckbox
value={model.name} value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>} label={<Text fontWeight={500}>{model.name}</Text>}
@ -98,7 +96,6 @@ function SearchModelEntry({
{model.location} {model.location}
</Text> </Text>
</Flex> </Flex>
</VStack>
); );
} }
@ -215,10 +212,10 @@ export default function SearchModels() {
} }
return ( return (
<> <Flex flexDirection="column" rowGap={4}>
{newFoundModels} {newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels} {shouldShowExistingModelsInSearch && existingFoundModels}
</> </Flex>
); );
}; };
@ -245,26 +242,26 @@ export default function SearchModels() {
<Text <Text
sx={{ sx={{
fontWeight: 500, fontWeight: 500,
fontSize: 'sm',
}} }}
variant="subtext" variant="subtext"
> >
{t('modelManager.checkpointFolder')} {t('modelManager.checkpointFolder')}
</Text> </Text>
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text> <Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex> </Flex>
<Spacer /> <Spacer />
<IAIIconButton <IAIIconButton
aria-label={t('modelManager.scanAgain')} aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')} tooltip={t('modelManager.scanAgain')}
icon={<BiReset />} icon={<FaSearch />}
fontSize={18} fontSize={18}
disabled={isProcessing} disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))} onClick={() => dispatch(searchForModels(searchFolder))}
/> />
<IAIIconButton <IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')} aria-label={t('modelManager.clearCheckpointFolder')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />} tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler} onClick={resetSearchModelHandler}
/> />
</Flex> </Flex>
@ -276,9 +273,9 @@ export default function SearchModels() {
}} }}
> >
{({ handleSubmit }) => ( {({ handleSubmit }) => (
<form onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end" width="100%"> <HStack columnGap={2} alignItems="flex-end">
<FormControl isRequired width="lg"> <FormControl flexGrow={1}>
<Field <Field
as={IAIInput} as={IAIInput}
id="checkpointFolder" id="checkpointFolder"
@ -294,12 +291,12 @@ export default function SearchModels() {
tooltip={t('modelManager.findModels')} tooltip={t('modelManager.findModels')}
type="submit" type="submit"
disabled={isProcessing} disabled={isProcessing}
paddingX={10} px={8}
> >
{t('modelManager.findModels')} {t('modelManager.findModels')}
</IAIButton> </IAIButton>
</HStack> </HStack>
</form> </IAIForm>
)} )}
</Formik> </Formik>
)} )}
@ -410,7 +407,6 @@ export default function SearchModels() {
maxHeight={72} maxHeight={72}
overflowY="scroll" overflowY="scroll"
borderRadius="sm" borderRadius="sm"
paddingInlineEnd={4}
gap={2} gap={2}
> >
{foundModels.length > 0 ? ( {foundModels.length > 0 ? (

View File

@ -50,7 +50,6 @@ export default function ThemeChanger() {
return ( return (
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
aria-label={t('common.themeLabel')} aria-label={t('common.themeLabel')}

View File

@ -166,20 +166,8 @@ export default function InvokeTabs() {
[] []
); );
/**
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
*
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
* the watcher for the hidden drawers, closing the open drawer.
*
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
* tab's drawer.
*/
return ( return (
<Tabs <Tabs
isLazy
defaultIndex={activeTab} defaultIndex={activeTab}
index={activeTab} index={activeTab}
onChange={(index: number) => { onChange={(index: number) => {

View File

@ -93,12 +93,9 @@ const ResizableDrawer = ({
useOutsideClick({ useOutsideClick({
ref: outsideClickRef, ref: outsideClickRef,
handler: () => { handler: () => {
if (isPinned) {
return;
}
onClose(); onClose();
}, },
enabled: isOpen && !isPinned,
}); });
const handleEnables = useMemo( const handleEnables = useMemo(

View File

@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {
return ( return (
<IAIPopover <IAIPopover
trigger="hover"
triggerComponent={ triggerComponent={
<Box <Box
sx={{ sx={{

View File

@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {
return ( return (
<IAIPopover <IAIPopover
trigger="hover" isLazy={false}
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')} tooltip={t('unifiedCanvas.canvasSettings')}

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,3 @@
import path from 'path';
import react from '@vitejs/plugin-react-swc'; import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer'; import { visualizer } from 'rollup-plugin-visualizer';
import { defineConfig, PluginOption } from 'vite'; import { defineConfig, PluginOption } from 'vite';
@ -58,26 +57,6 @@ export default defineConfig(({ mode }) => {
// sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit // sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit
}, },
}; };
} else if (mode === 'package') {
return {
...common,
build: {
...common.build,
lib: {
entry: path.resolve(__dirname, 'src/component.tsx'),
name: 'InvokeAI UI',
fileName: (format) => `invoke-ai-ui.${format}.js`,
},
rollupOptions: {
external: ['react', 'react-dom'],
output: {
globals: {
react: 'React',
},
},
},
},
};
} else { } else {
return { return {
...common, ...common,

View File

@ -38,16 +38,16 @@ dependencies = [
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==0.1.10", "compel==1.0.1",
"datasets", "datasets",
"diffusers[torch]~=0.14", "diffusers[torch]~=0.14",
"dnspython==2.2.1", "dnspython==2.2.1",
"einops", "einops",
"eventlet", "eventlet",
"facexlib", "facexlib",
"fastapi==0.85.0", "fastapi==0.94.1",
"fastapi-events==0.6.0", "fastapi-events==0.8.0",
"fastapi-socketio==0.0.9", "fastapi-socketio==0.0.10",
"flask==2.1.3", "flask==2.1.3",
"flask_cors==3.0.10", "flask_cors==3.0.10",
"flask_socketio==5.3.0", "flask_socketio==5.3.0",
@ -75,7 +75,7 @@ dependencies = [
"torchvision>=0.14.1", "torchvision>=0.14.1",
"torchmetrics", "torchmetrics",
"transformers~=4.26", "transformers~=4.26",
"uvicorn[standard]==0.20.0", "uvicorn[standard]==0.21.1",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]

View File

@ -105,17 +105,20 @@
// Start building nodes // Start building nodes
var id = 1; var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed}; var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
id++;
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
id++; id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" }; var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++ id++
nodes = {}; nodes = {};
nodes[initialNode.id] = initialNode; nodes[initialNode.id] = initialNode;
nodes[i2iNode.id] = i2iNode;
nodes[upscaleNode.id] = upscaleNode; nodes[upscaleNode.id] = upscaleNode;
links = [ links = [
[{ "node_id": initialNode.id, field: "image" }, { "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
{ "node_id": upscaleNode.id, field: "image" }] { "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
]; ];
// expandSize = 128; // expandSize = 128;
// for (var i = 0; i < 6; ++i) { // for (var i = 0; i < 6; ++i) {

View File

@ -1,15 +1,18 @@
from invokeai.app.invocations.image import * from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation from invokeai.app.invocations.upscale import UpscaleInvocation
import pytest import pytest
# Helpers # Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field)
)
# Tests # Tests
def test_connections_are_compatible(): def test_connections_are_compatible():
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
assert g.get_node("3").prompt == "Banana sushi" assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1 assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict(): def test_graph_fails_to_update_node_id_if_conflict():
g = Graph() g = Graph()
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
assert g2.nodes['1'] is not None assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None assert g2.nodes['2'] is not None
assert len(g2.edges) == 1 assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1' assert g2.edges[0].source.node_id == '1'
assert g2.edges[0][0].field == 'image' assert g2.edges[0].source.field == 'image'
assert g2.edges[0][1].node_id == '2' assert g2.edges[0].destination.node_id == '2'
assert g2.edges[0][1].field == 'image' assert g2.edges[0].destination.field == 'image'
def test_graph_can_generate_schema(): def test_graph_can_generate_schema():
# Not throwing on this line is sufficient # Not throwing on this line is sufficient

View File

@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import EdgeConnection from invokeai.app.services.graph import Edge, EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field))
class TestEvent: class TestEvent: