Merge branch 'main' into bugfix/dreambooth_ema

This commit is contained in:
Lincoln Stein 2023-03-23 23:24:15 -04:00 committed by GitHub
commit 6e7dbf99f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
72 changed files with 1060 additions and 790 deletions

View File

@ -1,6 +0,0 @@
[run]
omit='.env/*'
source='.'
[report]
show_missing = true

View File

@ -65,6 +65,16 @@ body:
placeholder: 8GB
validations:
required: false
- type: input
id: version-number
attributes:
label: What version did you experience this issue on?
description: |
Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: X.X.X
validations:
required: true
- type: textarea
id: what-happened

2
.gitignore vendored
View File

@ -63,6 +63,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coveragerc
.coverage
.coverage.*
.cache
@ -73,6 +74,7 @@ cov.xml
*.py,cover
.hypothesis/
.pytest_cache/
.pytest.ini
cover/
junit/

View File

@ -1,5 +0,0 @@
[pytest]
DJANGO_SETTINGS_MODULE = webtas.settings
; python_files = tests.py test_*.py *_tests.py
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml

4
coverage/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

Binary file not shown.

After

Width:  |  Height:  |  Size: 470 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 457 KiB

View File

@ -0,0 +1,83 @@
# Local Development
If you are looking to contribute you will need to have a local development
environment. See the
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
full details.
Broadly this involves cloning the repository, installing the pre-reqs, and
InvokeAI (in editable form). Assuming this is working, choose your area of
focus.
## Documentation
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
written in markdown files under the `./docs` folder and then built into a static
website for hosting with GitHub Pages at
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
To contribute to the documentation you'll need to install the dependencies. Note
the use of `"`.
```zsh
pip install ".[docs]"
```
Now, to run the documentation locally with hot-reloading for changes made.
```zsh
mkdocs serve
```
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
access.
## Backend
The backend is contained within the `./invokeai/backend` folder structure. To
get started however please install the development dependencies.
From the root of the repository run the following command. Note the use of `"`.
```zsh
pip install ".[test]"
```
This in an optional group of packages which is defined within the
`pyproject.toml` and will be required for testing the changes you make the the
code.
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition a more
detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML one in particular can help identify missing statements
requiring tests to ensure coverage. This can be run by opening
`./coverage/html/index.html`.
For example.
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
## Front End
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md"

View File

@ -17,7 +17,7 @@ notebooks.
You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS) to accelerate the
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.

View File

@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
brew install opencv
```
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
## Linux
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type
`from patchmatch import patch_match`: It should look like the follwing:
`from patchmatch import patch_match`: It should look like the following:
```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
If you see no errors, then you're ready to go!
If you see no errors you're ready to go!

View File

@ -10,6 +10,7 @@ from pydantic.fields import Field
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
@ -92,7 +93,7 @@ async def get_session(
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"),
) -> str:
"""Adds a node to the graph"""
@ -125,7 +126,7 @@ async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
@ -186,7 +187,7 @@ async def delete_node(
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -228,9 +229,9 @@ async def delete_edge(
return Response(status_code=404)
try:
edge = (
EdgeConnection(node_id=from_node_id, field=from_field),
EdgeConnection(node_id=to_node_id, field=to_field),
edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field)
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(

View File

@ -112,10 +112,8 @@ def custom_openapi():
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
if "additionalProperties" not in invoker_schema:
invoker_schema["additionalProperties"] = {}
invoker_schema["additionalProperties"]["outputs"] = outputs_ref
invoker_schema["output"] = outputs_ref
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import EdgeConnection, GraphExecutionState
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
@ -94,9 +94,9 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
edges = [
(
EdgeConnection(node_id=a.id, field=field),
EdgeConnection(node_id=b.id, field=field),
Edge(
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
)
for field in matching_fields
]
@ -111,16 +111,15 @@ class SessionError(Exception):
def invoke_all(context: CliContext):
"""Runs all invocations in the specified session"""
context.invoker.invoke(context.session, invoke_all=True)
while not context.session.is_complete():
while not context.get_session().is_complete():
# Wait some time
session = context.get_session()
time.sleep(0.1)
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
raise SessionError()
@ -203,7 +202,7 @@ def invoke_cli():
continue
# Pipe previous command output (if there was a previous command)
edges = []
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
@ -225,19 +224,19 @@ def invoke_cli():
matching_edges = generate_matching_edges(
link_node, command.command
)
matching_destinations = [e[1] for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations]
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges.append(
(
EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection(
Edge(
source=EdgeConnection(node_id=link[1], field=link[0]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
),
)
)
)

View File

@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
float(step) / float(self.steps),
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
@ -118,7 +136,7 @@ class ImageToImageInvocation(TextToImageInvocation):
generator_output = next(
Img2Img(model).generate(
prompt=self.prompt,
init_img=image,
init_image=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
@ -179,8 +197,8 @@ class InpaintInvocation(ImageToImageInvocation):
generator_output = next(
Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
init_image=image,
mask_image=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}

View File

@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@ -23,8 +26,9 @@ class EventServiceBase:
self,
graph_execution_state_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
percent: float,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
@ -32,8 +36,9 @@ class EventServiceBase:
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
percent=percent,
total_steps=total_steps,
),
)

View File

@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
return hash(f"{self.node_id}.{self.field}")
class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
@ -194,7 +199,7 @@ class Graph(BaseModel):
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
)
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@ -251,7 +256,7 @@ class Graph(BaseModel):
except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def add_edge(self, edge: Edge) -> None:
"""Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid.
@ -262,7 +267,7 @@ class Graph(BaseModel):
else:
raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def delete_edge(self, edge: Edge) -> None:
"""Deletes an edge from a graph"""
try:
@ -280,7 +285,7 @@ class Graph(BaseModel):
# Validate all edges reference nodes in the graph
node_ids = set(
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
)
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
@ -294,10 +299,10 @@ class Graph(BaseModel):
if not all(
(
are_connections_compatible(
self.get_node(e[0].node_id),
e[0].field,
self.get_node(e[1].node_id),
e[1].field,
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
)
for e in self.edges
)
@ -328,58 +333,58 @@ class Graph(BaseModel):
return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge[0].node_id)
to_node = self.get_node(edge[1].node_id)
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id)
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
return False
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge[0].field, to_node, edge[1].field
from_node, edge.source.field, to_node, edge.destination.field
):
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
@ -438,15 +443,15 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[1].node_id
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
edge[0],
EdgeConnection(
node_id=new_graph_node_path, field=edge[1].field
),
Edge(
source=edge.source,
destination=EdgeConnection(
node_id=new_graph_node_path, field=edge.destination.field
)
)
)
@ -454,51 +459,51 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[0].node_id
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
EdgeConnection(
node_id=new_graph_node_path, field=edge[0].field
Edge(
source=EdgeConnection(
node_id=new_graph_node_path, field=edge.source.field
),
edge[1],
destination=edge.destination
)
)
def _get_input_edges(
self, node_path: str, field: Optional[str] = None
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
)
node_id = (
@ -522,37 +527,37 @@ class Graph(BaseModel):
def _get_output_edges(
self, node_path: str, field: str
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field)
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
)
node_id = (
@ -580,8 +585,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
if new_input is not None:
inputs.append(new_input)
@ -622,8 +627,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
if new_input is not None:
inputs.append(new_input)
@ -684,7 +689,7 @@ class Graph(BaseModel):
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
@ -711,7 +716,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from(
[
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
input_collection_prepared_node_id = next(
n[1]
for n in iteration_node_map
if n[0] == input_collection_edge[0].node_id
if n[0] == input_collection_edge.source.node_id
)
input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id
]
input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge[0].field
input_collection_prepared_node_output, input_collection_edge.source.field
)
self_iteration_count = len(input_collection)
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
new_edges = list()
for edge in input_edges:
for input_node_id in (
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
):
new_edge = (
EdgeConnection(node_id=input_node_id, field=edge[0].field),
EdgeConnection(node_id="", field=edge[1].field),
new_edge = Edge(
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
destination=EdgeConnection(node_id="", field=edge.destination.field),
)
new_edges.append(new_edge)
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
# Add new edges to execution graph
for edge in new_edges:
new_edge = (
edge[0],
EdgeConnection(node_id=new_node.id, field=edge[1].field),
new_edge = Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
)
self.execution_graph.add_edge(new_edge)
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
if isinstance(node, CollectInvocation):
output_collection = [
getattr(self.results[edge[0].node_id], edge[0].field)
getattr(self.results[edge.source.node_id], edge.source.field)
for edge in input_edges
if edge[1].field == "item"
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
setattr(node, edge[1].field, output_value)
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
setattr(node, edge.destination.field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge):
return False
# Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping:
if edge.destination.node_id in self.source_prepared_mapping:
return False
# Otherwise, the edge is valid
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
)
self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
)
self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def delete_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
)
self.graph.delete_edge(edge)

View File

@ -490,7 +490,7 @@ class Args(object):
"-z",
type=int,
default=6,
choices=range(0, 9),
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
)
@ -943,7 +943,6 @@ class Args(object):
"--png_compression",
"-z",
type=int,
default=6,
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",

View File

@ -58,7 +58,7 @@ class InvokeAIGeneratorOutput:
'''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
@ -116,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
for o in outputs:
print(o.image, o.seed)
'''
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
@ -167,7 +168,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
@ -177,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
@ -267,12 +268,12 @@ class Embiggen(Txt2Img):
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
@ -347,7 +348,6 @@ class Generator:
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback=attention_maps_callback,
seed=seed,
**kwargs,
)
results = []
@ -375,7 +375,8 @@ class Generator:
print("** An error occurred while getting initial noise **")
print(traceback.format_exc())
image = make_image(x_T)
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
@ -497,7 +498,8 @@ class Generator:
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result
def sample_to_lowres_estimated_image(self, samples):
@staticmethod
def sample_to_lowres_estimated_image(samples):
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

View File

@ -37,7 +37,6 @@ class Img2Img(Generator):
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
seed=None,
**kwargs,
):
"""
@ -64,7 +63,7 @@ class Img2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
@ -77,7 +76,7 @@ class Img2Img(Generator):
conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed
seed=seed,
)
if (
pipeline_output.attention_map_saver is not None
@ -88,9 +87,7 @@ class Img2Img(Generator):
return make_image
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
if seed is not None:
set_seed(seed)
def get_noise_like(self, like: torch.Tensor):
device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device)

View File

@ -159,6 +159,7 @@ class Inpaint(Img2Img):
seam_size: int,
seam_blur: int,
prompt,
seed,
sampler,
steps,
cfg_scale,
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise)
result = make_image(seam_noise, seed)
return result
@ -223,7 +224,6 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
seed=None,
**kwargs,
):
"""
@ -311,7 +311,7 @@ class Inpaint(Img2Img):
uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image,
mask=1 - mask, # expects white means "paint here."
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
conditioning_data=conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed
seed=seed,
)
if (
@ -343,6 +343,7 @@ class Inpaint(Img2Img):
seam_size,
seam_blur,
prompt,
seed,
sampler,
seam_steps,
cfg_scale,

View File

@ -61,7 +61,7 @@ class Txt2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T) -> PIL.Image.Image:
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T,

View File

@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T),
num_inference_steps=steps,

View File

@ -1085,9 +1085,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
dlogging.set_verbosity_error()
checkpoint = (
load_file(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors"
else torch.load(checkpoint_path)
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub")
pipeline_class = (

View File

@ -97,7 +97,7 @@ class ModelManager(object):
If on disk, will load from there.
"""
if not model_name:
return self.current_model if self.current_model else self.get_model(self.default_model())
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
@ -362,6 +362,7 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
@ -436,7 +437,6 @@ class ModelManager(object):
height = width
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@ -732,9 +732,9 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
# additional probing needed if no config file provided

View File

@ -6,7 +6,6 @@ The interface is through the Concepts() object.
"""
import os
import re
import traceback
from typing import Callable
from urllib import error as ul_error
from urllib import request
@ -15,7 +14,6 @@ from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
"""
if not concept_name in self.list_concepts():
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
except ul_error.HTTPError as e:
if e.code == 404:
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
return False
except ul_error.URLError as e:
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import einops
import PIL.Image
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype,
)
noise = noise_func(initial_latents, seed)
if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(
initial_latents,
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype
)
noise = noise_func(init_image_latents, seed)
if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)
if mask.dim() == 3:
mask = mask.unsqueeze(0)

View File

@ -3,6 +3,9 @@ import math
import multiprocessing as mp
import os
import re
import io
import base64
from collections import abc
from inspect import isfunction
from pathlib import Path
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-e1f916bd.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-548a355c.js";var Or=`
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
:root {
--chakra-vh: 100vh;
}

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-e1f916bd.js"></script>
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
<link rel="stylesheet" href="./assets/index-5483945c.css">
</head>

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

View File

@ -1,3 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
export {};
declare module 'redux-socket.io-middleware';
@ -39,3 +43,36 @@ declare global {
}
/* eslint-enable @typescript-eslint/no-explicit-any */
}
declare module '@invoke-ai/invoke-ai-ui' {
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
public constructor(props: ThemeChangerProps);
}
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
public constructor(props: InvokeAILogoComponentProps);
}
declare class IAIPopover extends React.Component<IAIPopoverProps> {
public constructor(props: IAIPopoverProps);
}
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
public constructor(props: IAIIconButtonProps);
}
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
export {
ThemeChanger,
InvokeAiLogoComponent,
IAIPopover,
IAIIconButton,
SettingsModal,
};
export = Invoke;

View File

@ -36,6 +36,7 @@
},
"dependencies": {
"@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/cli": "^2.3.0",
"@chakra-ui/icons": "^2.0.17",
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
@ -52,6 +53,7 @@
"i18next-http-backend": "^2.1.1",
"konva": "^8.4.2",
"lodash": "^4.17.21",
"patch-package": "^6.5.1",
"re-resizable": "^6.9.9",
"react": "^18.2.0",
"react-colorful": "^5.6.1",
@ -72,7 +74,6 @@
"uuid": "^9.0.0"
},
"devDependencies": {
"@chakra-ui/cli": "^2.3.0",
"@fontsource/inter": "^4.5.15",
"@types/dateformat": "^5.0.0",
"@types/react": "^18.0.28",
@ -92,7 +93,6 @@
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"patch-package": "^6.5.1",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

View File

@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks';
import { useEffect } from 'react';
import { PropsWithChildren, useEffect } from 'react';
keepGUIAlive();
const App = () => {
const App = (props: PropsWithChildren) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
@ -40,7 +40,7 @@ const App = () => {
w={APP_WIDTH}
h={APP_HEIGHT}
>
<SiteHeader />
{props.children || <SiteHeader />}
<Flex gap={4} w="full" h="full">
<InvokeTabs />
<ImageGalleryPanel />

View File

@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [

View File

@ -9,6 +9,7 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
type Props = {
@ -22,10 +23,12 @@ type Props = {
};
const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { t } = useTranslation();
const {
acceptButtonText = 'Accept',
acceptButtonText = t('common.accept'),
acceptCallback,
cancelButtonText = 'Cancel',
cancelButtonText = t('common.cancel'),
cancelCallback,
children,
title,
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>

View File

@ -0,0 +1,8 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@ -0,0 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { ReactElement } from 'react';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}

View File

@ -8,13 +8,14 @@ import {
import { memo } from 'react';
export type IAIIconButtonProps = IconButtonProps & {
role?: string;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
isChecked?: boolean;
};
const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
const { tooltip = '', tooltipProps, isChecked, ...rest } = props;
const { role, tooltip = '', tooltipProps, isChecked, ...rest } = props;
return (
<Tooltip
@ -27,6 +28,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
>
<IconButton
ref={forwardedRef}
role={role}
aria-checked={isChecked !== undefined ? isChecked : undefined}
{...rest}
/>
@ -34,4 +36,5 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
);
});
IAIIconButton.displayName = 'IAIIconButton';
export default memo(IAIIconButton);

View File

@ -8,7 +8,7 @@ import {
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIPopoverProps = PopoverProps & {
export type IAIPopoverProps = PopoverProps & {
triggerComponent: ReactNode;
triggerContainerProps?: BoxProps;
children: ReactNode;

View File

@ -1,4 +1,4 @@
import React, { lazy } from 'react';
import React, { lazy, PropsWithChildren } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store';
@ -21,14 +21,14 @@ import './i18n';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component() {
export default function Component(props: PropsWithChildren) {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App />
<App>{props.children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@ -0,0 +1,16 @@
import Component from './component';
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
export default Component;
export {
InvokeAiLogoComponent,
ThemeChanger,
IAIPopover,
IAIIconButton,
SettingsModal,
};

View File

@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<ButtonGroup>
<IAIIconButton

View File

@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

View File

@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
onClick={handleSelectColorPickerTool}
/>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('unifiedCanvas.brushOptions')}

View File

@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
>
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={`${t('parameters.sendTo')}...`}
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaGrinStars />}
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</IAIPopover>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaExpandArrowsAlt />}

View File

@ -0,0 +1,24 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default CurrentImageFallback;

View File

@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
src={imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
sx={{
objectFit: 'contain',
maxWidth: '100%',

View File

@ -101,12 +101,14 @@ const ImageGalleryContent = () => {
aria-label={t('gallery.showGenerations')}
tooltip={t('gallery.showGenerations')}
isChecked={currentCategory === 'result'}
role="radio"
icon={<FaImage />}
onClick={() => dispatch(setCurrentCategory('result'))}
/>
<IAIIconButton
aria-label={t('gallery.showUploads')}
tooltip={t('gallery.showUploads')}
role="radio"
isChecked={currentCategory === 'user'}
icon={<FaUser />}
onClick={() => dispatch(setCurrentCategory('user'))}
@ -251,4 +253,5 @@ const ImageGalleryContent = () => {
);
};
ImageGalleryContent.displayName = 'ImageGalleryContent';
export default ImageGalleryContent;

View File

@ -55,7 +55,6 @@ export default function LanguagePicker() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.languagePickerLabel')}

View File

@ -1,4 +1,5 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/invokeai';
import type { RootState } from 'app/store';
import IAIIconButton from 'common/components/IAIIconButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import { BiArrowBack } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
return (
<VStack gap={2} alignItems="flex-start">
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Flex columnGap={4}>
<IAICheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
<SearchModels />
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
{addManually && (
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="2xl"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="2xl"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="2xl"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="2xl"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="2xl"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
width="90%"
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
width="90%"
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);

View File

@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import { BiArrowBack } from 'react-icons/bi';
import type { RootState } from 'app/store';
import type { ReactElement } from 'react';
function FormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
return (
<Flex>
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -14,7 +14,7 @@ import {
import IAIButton from 'common/components/IAIButton';
import { FaPlus } from 'react-icons/fa';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next';
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
@ -83,8 +84,22 @@ export default function AddModel() {
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (

View File

@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
import { isEqual, pickBy } from 'lodash';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -12,14 +12,13 @@ import {
RadioGroup,
Spacer,
Text,
VStack,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next';
import { FaPlus, FaSearch } from 'react-icons/fa';
import { FaSearch, FaTrash } from 'react-icons/fa';
import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system;
@ -71,34 +70,32 @@ function SearchModelEntry({
};
return (
<VStack>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4}>
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4} alignItems="center" justifyContent="space-between">
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
</VStack>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
</Flex>
);
}
@ -215,10 +212,10 @@ export default function SearchModels() {
}
return (
<>
<Flex flexDirection="column" rowGap={4}>
{newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels}
</>
</Flex>
);
};
@ -245,26 +242,26 @@ export default function SearchModels() {
<Text
sx={{
fontWeight: 500,
fontSize: 'sm',
}}
variant="subtext"
>
{t('modelManager.checkpointFolder')}
</Text>
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex>
<Spacer />
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<BiReset />}
icon={<FaSearch />}
fontSize={18}
disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))}
/>
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler}
/>
</Flex>
@ -276,9 +273,9 @@ export default function SearchModels() {
}}
>
{({ handleSubmit }) => (
<form onSubmit={handleSubmit}>
<HStack columnGap={2} alignItems="flex-end" width="100%">
<FormControl isRequired width="lg">
<IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end">
<FormControl flexGrow={1}>
<Field
as={IAIInput}
id="checkpointFolder"
@ -294,12 +291,12 @@ export default function SearchModels() {
tooltip={t('modelManager.findModels')}
type="submit"
disabled={isProcessing}
paddingX={10}
px={8}
>
{t('modelManager.findModels')}
</IAIButton>
</HStack>
</form>
</IAIForm>
)}
</Formik>
)}
@ -410,7 +407,6 @@ export default function SearchModels() {
maxHeight={72}
overflowY="scroll"
borderRadius="sm"
paddingInlineEnd={4}
gap={2}
>
{foundModels.length > 0 ? (

View File

@ -61,47 +61,53 @@ const SiteHeader = () => {
<LanguagePicker />
<IAIIconButton
aria-label={t('common.reportBugLabel')}
tooltip={t('common.reportBugLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={
<Link isExternal href="http://github.com/invoke-ai/InvokeAI/issues">
<FaBug />
</Link>
}
/>
<Link
isExternal
href="http://github.com/invoke-ai/InvokeAI/issues"
marginBottom="-0.25rem"
>
<IAIIconButton
aria-label={t('common.reportBugLabel')}
tooltip={t('common.reportBugLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={<FaBug />}
/>
</Link>
<IAIIconButton
aria-label={t('common.githubLabel')}
tooltip={t('common.githubLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={
<Link isExternal href="http://github.com/invoke-ai/InvokeAI">
<FaGithub />
</Link>
}
/>
<Link
isExternal
href="http://github.com/invoke-ai/InvokeAI"
marginBottom="-0.25rem"
>
<IAIIconButton
aria-label={t('common.githubLabel')}
tooltip={t('common.githubLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={<FaGithub />}
/>
</Link>
<IAIIconButton
aria-label={t('common.discordLabel')}
tooltip={t('common.discordLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={
<Link isExternal href="https://discord.gg/ZmtBAhwWhy">
<FaDiscord />
</Link>
}
/>
<Link
isExternal
href="https://discord.gg/ZmtBAhwWhy"
marginBottom="-0.25rem"
>
<IAIIconButton
aria-label={t('common.discordLabel')}
tooltip={t('common.discordLabel')}
variant="link"
data-variant="link"
fontSize={20}
size="sm"
icon={<FaDiscord />}
/>
</Link>
<SettingsModal>
<IAIIconButton
@ -119,4 +125,5 @@ const SiteHeader = () => {
);
};
SiteHeader.displayName = 'SiteHeader';
export default SiteHeader;

View File

@ -50,7 +50,6 @@ export default function ThemeChanger() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.themeLabel')}

View File

@ -166,20 +166,8 @@ export default function InvokeTabs() {
[]
);
/**
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
*
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
* the watcher for the hidden drawers, closing the open drawer.
*
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
* tab's drawer.
*/
return (
<Tabs
isLazy
defaultIndex={activeTab}
index={activeTab}
onChange={(index: number) => {

View File

@ -93,12 +93,9 @@ const ResizableDrawer = ({
useOutsideClick({
ref: outsideClickRef,
handler: () => {
if (isPinned) {
return;
}
onClose();
},
enabled: isOpen && !isPinned,
});
const handleEnables = useMemo(

View File

@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<Box
sx={{

View File

@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {
return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

File diff suppressed because one or more lines are too long

View File

@ -46,7 +46,7 @@ export default defineConfig(({ mode }) => {
* overrides any target specified here.
*/
// target: 'esnext',
chunkSizeWarningLimit: 1500, // we don't really care about chunk size
chunkSizeWarningLimit: 1500, // we don't really care about chunk size,
},
};
if (mode == 'development') {

View File

@ -38,16 +38,16 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==0.1.10",
"compel==1.0.4",
"datasets",
"diffusers[torch]~=0.14",
"dnspython==2.2.1",
"einops",
"eventlet",
"facexlib",
"fastapi==0.85.0",
"fastapi-events==0.6.0",
"fastapi-socketio==0.0.9",
"fastapi==0.94.1",
"fastapi-events==0.8.0",
"fastapi-socketio==0.0.10",
"flask==2.1.3",
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
@ -75,7 +75,7 @@ dependencies = [
"torchvision>=0.14.1",
"torchmetrics",
"transformers~=4.26",
"uvicorn[standard]==0.20.0",
"uvicorn[standard]==0.21.1",
"windows-curses; sys_platform=='win32'",
]
@ -139,8 +139,24 @@ version = { attr = "invokeai.version.__version__" }
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
"invokeai.frontend.web.dist" = ["**"]
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "-p pytest_cov --junitxml=junit/test-results.xml --cov-report=term:skip-covered --cov=ldm/invoke --cov=backend --cov-branch"
addopts = "--cov-report term --cov-report html --cov-report xml"
[tool.coverage.run]
branch = true
source = ["invokeai"]
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
[tool.coverage.report]
show_missing = true
fail_under = 85 # let's set something sensible on Day 1 ...
[tool.coverage.json]
output = "coverage/coverage.json"
pretty_print = true
[tool.coverage.html]
directory = "coverage/html"
[tool.coverage.xml]
output = "coverage/index.xml"
#=== End: PyTest and Coverage
[flake8]
max-line-length = 120

View File

@ -105,17 +105,20 @@
// Start building nodes
var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
id++;
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++
nodes = {};
nodes[initialNode.id] = initialNode;
nodes[i2iNode.id] = i2iNode;
nodes[upscaleNode.id] = upscaleNode;
links = [
[{ "node_id": initialNode.id, field: "image" },
{ "node_id": upscaleNode.id, field: "image" }]
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
];
// expandSize = 128;
// for (var i = 0; i < 6; ++i) {

View File

@ -1,15 +1,18 @@
from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
import pytest
# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field)
)
# Tests
def test_connections_are_compatible():
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
assert g2.edges[0].source.node_id == '1'
assert g2.edges[0].source.field == 'image'
assert g2.edges[0].destination.node_id == '2'
assert g2.edges[0].destination.field == 'image'
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient

View File

@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import EdgeConnection
from invokeai.app.services.graph import Edge, EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field))
class TestEvent: