Merge branch 'main' into patch-1

Lincoln Stein 2023-03-25 10:48:25 -04:00 committed by GitHub
commit 07ea806553
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
85 changed files with 1132 additions and 873 deletions

@ -1,6 +0,0 @@
[run]
omit='.env/*'
source='.'
[report]
show_missing = true

.github/CODEOWNERS

@ -1,16 +1,16 @@
# continuous integration
/.github/workflows/ @mauwii @lstein
/.github/workflows/ @mauwii @lstein @blessedcoolant
# documentation
/docs/ @lstein @mauwii @tildebyte
/mkdocs.yml @lstein @mauwii
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
/mkdocs.yml @lstein @mauwii @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
# installation and configuration
/pyproject.toml @mauwii @lstein @blessedcoolant
/docker/ @mauwii @lstein
/docker/ @mauwii @lstein @blessedcoolant
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr

@ -65,6 +65,16 @@ body:
placeholder: 8GB
validations:
required: false
- type: input
id: version-number
attributes:
label: What version did you experience this issue on?
description: |
Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: X.X.X
validations:
required: true
- type: textarea
id: what-happened

@ -34,6 +34,8 @@ jobs:
- name: deploy to gh-pages
if: ${{ github.ref == 'refs/heads/main' }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
run: |
python -m \
mkdocs gh-deploy \

@ -6,7 +6,6 @@ on:
- '!pyproject.toml'
- '!invokeai/**'
- 'invokeai/frontend/web/**'
- '!invokeai/frontend/web/dist/**'
merge_group:
workflow_dispatch:

@ -7,13 +7,11 @@ on:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
pull_request:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
types:
- 'ready_for_review'
- 'opened'

.gitignore

@ -63,6 +63,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coveragerc
.coverage
.coverage.*
.cache
@ -73,6 +74,7 @@ cov.xml
*.py,cover
.hypothesis/
.pytest_cache/
.pytest.ini
cover/
junit/

@ -1,5 +0,0 @@
[pytest]
DJANGO_SETTINGS_MODULE = webtas.settings
; python_files = tests.py test_*.py *_tests.py
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml

@ -139,7 +139,7 @@ not supported.
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```
_For Linux with an AMD GPU:_

coverage/.gitignore (new file)

@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

(binary image added, 470 KiB)

(binary image added, 457 KiB)

@ -0,0 +1,83 @@
# Local Development
If you are looking to contribute you will need to have a local development
environment. See the
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
full details.
Broadly, this involves cloning the repository, installing the prerequisites,
and installing InvokeAI itself in editable form. Assuming this is working,
choose your area of focus.
## Documentation
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
written in markdown files under the `./docs` folder and then built into a static
website for hosting with GitHub Pages at
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
To contribute to the documentation you'll need to install the dependencies. Note
the use of `"`.
```zsh
pip install ".[docs]"
```
Now run the documentation locally, with hot-reloading of any changes you make:
```zsh
mkdocs serve
```
You can then browse to `http://127.0.0.1:8080` to view the documentation.
## Backend
The backend is contained within the `./invokeai/backend` folder structure. To
get started, first install the development dependencies.
From the root of the repository run the following command. Note the use of `"`.
```zsh
pip install ".[test]"
```
This is an optional group of packages defined within `pyproject.toml`; it is
required for testing the changes you make to the code.
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a
more detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML report in particular can help identify missing statements that
still need tests; view it by opening `./coverage/html/index.html`. For example:
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
## Front End
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md"

@ -17,7 +17,7 @@ notebooks.
You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS) to accelerate the
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.

@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
brew install opencv
```
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
## Linux
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type
`from patchmatch import patch_match`: It should look like the follwing:
`from patchmatch import patch_match`: It should look like the following:
```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
If you see no errors, then you're ready to go!
If you see no errors you're ready to go!

@ -10,6 +10,7 @@ from pydantic.fields import Field
from ...invocations import *
from ...invocations.baseinvocation import BaseInvocation
from ...services.graph import (
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
@ -92,7 +93,7 @@ async def get_session(
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"),
) -> str:
"""Adds a node to the graph"""
@ -125,7 +126,7 @@ async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
@ -186,7 +187,7 @@ async def delete_node(
)
async def add_edge(
session_id: str = Path(description="The id of the session"),
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
edge: Edge = Body(description="The edge to add"),
) -> GraphExecutionState:
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@ -228,9 +229,9 @@ async def delete_edge(
return Response(status_code=404)
try:
edge = (
EdgeConnection(node_id=from_node_id, field=from_field),
EdgeConnection(node_id=to_node_id, field=to_field),
edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field)
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
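
Since `Edge` is now a pydantic model rather than a bare tuple, the `add_edge`
request body becomes a named JSON object. A minimal sketch of the shape a
client would now send (the node ids and fields here are hypothetical):

```py
# Hypothetical JSON body for the add_edge route; node ids are placeholders.
edge_body = {
    "source": {"node_id": "txt2img_1", "field": "image"},
    "destination": {"node_id": "upscale_1", "field": "image"},
}
```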

@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import EdgeConnection, GraphExecutionState
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
atype = type(a)
btype = type(b)
@ -94,9 +94,9 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)
edges = [
(
EdgeConnection(node_id=a.id, field=field),
EdgeConnection(node_id=b.id, field=field),
Edge(
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
)
for field in matching_fields
]
@ -111,16 +111,15 @@ class SessionError(Exception):
def invoke_all(context: CliContext):
"""Runs all invocations in the specified session"""
context.invoker.invoke(context.session, invoke_all=True)
while not context.session.is_complete():
while not context.get_session().is_complete():
# Wait some time
session = context.get_session()
time.sleep(0.1)
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
raise SessionError()
@ -203,7 +202,7 @@ def invoke_cli():
continue
# Pipe previous command output (if there was a previous command)
edges = []
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
@ -225,19 +224,19 @@ def invoke_cli():
matching_edges = generate_matching_edges(
link_node, command.command
)
matching_destinations = [e[1] for e in matching_edges]
edges = [e for e in edges if e[1] not in matching_destinations]
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges.append(
(
EdgeConnection(node_id=link[1], field=link[0]),
EdgeConnection(
Edge(
source=EdgeConnection(node_id=link[1], field=link[0]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
),
)
)
)

@ -4,6 +4,8 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Any = None, step: int = 0
) -> None:
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
float(step) / float(self.steps),
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
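
Taken together, the new progress path decodes the intermediate latents to a
low-resolution RGB estimate, scales the latent-space dimensions to pixel space
by the VAE factor of 8, and ships the preview as a JPEG data URL. A condensed
sketch; the absolute import paths are assumptions mirroring the relative
imports shown above:

```py
# Condensed sketch of the new progress flow; import paths are assumptions.
from invokeai.backend.generator import Generator
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.util.util import image_to_dataURL

def preview_payload(state: PipelineIntermediateState) -> dict:
    image = Generator.sample_to_lowres_estimated_image(state.latents)  # now a @staticmethod
    (width, height) = image.size
    # latent-space dimensions map to pixel space at a factor of 8
    return {
        "width": width * 8,
        "height": height * 8,
        "dataURL": image_to_dataURL(image, image_format="JPEG"),
    }
```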

@ -1,7 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict
from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@ -23,8 +26,9 @@ class EventServiceBase:
self,
graph_execution_state_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
percent: float,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
@ -32,8 +36,9 @@ class EventServiceBase:
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
percent=percent,
total_steps=total_steps,
),
)
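
The `ProgressImage` TypedDict pins down the preview payload that
`emit_generator_progress` now carries. A minimal conforming value (the base64
body is elided):

```py
from typing import TypedDict

ProgressImage = TypedDict(
    "ProgressImage", {"dataURL": str, "width": int, "height": int}
)

# A conforming preview payload; the encoded image body is elided.
preview: ProgressImage = {
    "dataURL": "data:image/jpeg;base64,...",
    "width": 512,
    "height": 512,
}
```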

@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
return hash(f"{self.node_id}.{self.field}")
class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
@ -194,7 +199,7 @@ class Graph(BaseModel):
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
)
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@ -251,7 +256,7 @@ class Graph(BaseModel):
except NodeNotFoundError:
pass # Ignore, node doesn't exist (should this throw?)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def add_edge(self, edge: Edge) -> None:
"""Adds an edge to a graph
:raises InvalidEdgeError: the provided edge is invalid.
@ -262,7 +267,7 @@ class Graph(BaseModel):
else:
raise InvalidEdgeError()
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
def delete_edge(self, edge: Edge) -> None:
"""Deletes an edge from a graph"""
try:
@ -280,7 +285,7 @@ class Graph(BaseModel):
# Validate all edges reference nodes in the graph
node_ids = set(
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
)
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
@ -294,10 +299,10 @@ class Graph(BaseModel):
if not all(
(
are_connections_compatible(
self.get_node(e[0].node_id),
e[0].field,
self.get_node(e[1].node_id),
e[1].field,
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
)
for e in self.edges
)
@ -328,58 +333,58 @@ class Graph(BaseModel):
return True
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge[0].node_id)
to_node = self.get_node(edge[1].node_id)
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge[0].node_id, edge[1].node_id)
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
return False
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge[0].field, to_node, edge[1].field
from_node, edge.source.field, to_node, edge.destination.field
):
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge[1].node_id, new_input=edge[0]
edge.destination.node_id, new_input=edge.source
):
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge[0].node_id, new_output=edge[1]
edge.source.node_id, new_output=edge.destination
):
return False
@ -438,15 +443,15 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[1].node_id
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
edge[0],
EdgeConnection(
node_id=new_graph_node_path, field=edge[1].field
),
Edge(
source=edge.source,
destination=EdgeConnection(
node_id=new_graph_node_path, field=edge.destination.field
)
)
)
@ -454,51 +459,51 @@ class Graph(BaseModel):
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge[0].node_id
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
(
EdgeConnection(
node_id=new_graph_node_path, field=edge[0].field
Edge(
source=EdgeConnection(
node_id=new_graph_node_path, field=edge.source.field
),
edge[1],
destination=edge.destination
)
)
def _get_input_edges(
self, node_path: str, field: Optional[str] = None
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
)
node_id = (
@ -522,37 +527,37 @@ class Graph(BaseModel):
def _get_output_edges(
self, node_path: str, field: str
) -> list[tuple[EdgeConnection, EdgeConnection]]:
) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2][0].field == field)
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
(
EdgeConnection(
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
field=e[0].field,
),
EdgeConnection(
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
field=e[1].field,
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
)
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
) -> list[tuple["Graph", str, Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
# Return any input edges that appear in this graph
edges.extend(
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
)
node_id = (
@ -580,8 +585,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
if new_input is not None:
inputs.append(new_input)
@ -622,8 +627,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
if new_input is not None:
inputs.append(new_input)
@ -684,7 +689,7 @@ class Graph(BaseModel):
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
@ -711,7 +716,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from(
[
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
def next(self) -> BaseInvocation | None:
"""Gets the next node ready to execute."""
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
input_collection_prepared_node_id = next(
n[1]
for n in iteration_node_map
if n[0] == input_collection_edge[0].node_id
if n[0] == input_collection_edge.source.node_id
)
input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id
]
input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge[0].field
input_collection_prepared_node_output, input_collection_edge.source.field
)
self_iteration_count = len(input_collection)
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
new_edges = list()
for edge in input_edges:
for input_node_id in (
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
):
new_edge = (
EdgeConnection(node_id=input_node_id, field=edge[0].field),
EdgeConnection(node_id="", field=edge[1].field),
new_edge = Edge(
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
destination=EdgeConnection(node_id="", field=edge.destination.field),
)
new_edges.append(new_edge)
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
# Add new edges to execution graph
for edge in new_edges:
new_edge = (
edge[0],
EdgeConnection(node_id=new_node.id, field=edge[1].field),
new_edge = Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
)
self.execution_graph.add_edge(new_edge)
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
return self.execution_graph.nodes[next_node]
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
if isinstance(node, CollectInvocation):
output_collection = [
getattr(self.results[edge[0].node_id], edge[0].field)
getattr(self.results[edge.source.node_id], edge.source.field)
for edge in input_edges
if edge[1].field == "item"
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
setattr(node, edge[1].field, output_value)
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
setattr(node, edge.destination.field, output_value)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
def _is_edge_valid(self, edge: Edge) -> bool:
if not self._is_edge_valid(edge):
return False
# Invalid if destination has already been prepared or executed
if edge[1].node_id in self.source_prepared_mapping:
if edge.destination.node_id in self.source_prepared_mapping:
return False
# Otherwise, the edge is valid
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
)
self.graph.delete_node(node_path)
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
)
self.graph.add_edge(edge)
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
if not self._is_node_updatable(edge[1].node_id):
def delete_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
raise NodeAlreadyExecutedError(
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
)
self.graph.delete_edge(edge)
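
The thread running through this whole file is the replacement of positional
`tuple[EdgeConnection, EdgeConnection]` pairs with the named `Edge` model. A
self-contained before/after sketch, assuming pydantic as used here (the
`EdgeConnection` field descriptions are paraphrased):

```py
from pydantic import BaseModel, Field

class EdgeConnection(BaseModel):
    node_id: str = Field(description="The id of the node")    # paraphrased
    field: str = Field(description="The field on that node")  # paraphrased

class Edge(BaseModel):
    source: EdgeConnection = Field(description="The connection for the edge's from node and field")
    destination: EdgeConnection = Field(description="The connection for the edge's to node and field")

# Before: positional tuples, indexed as e[0]/e[1]
old_edge = (
    EdgeConnection(node_id="a", field="image"),
    EdgeConnection(node_id="b", field="image"),
)

# After: named fields, read as e.source/e.destination
new_edge = Edge(
    source=EdgeConnection(node_id="a", field="image"),
    destination=EdgeConnection(node_id="b", field="image"),
)
assert new_edge.source.node_id == "a"
```

The payoff is visible throughout the diff: every `e[0]`/`e[1]` index becomes a
self-documenting `e.source`/`e.destination` access.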

@ -490,7 +490,7 @@ class Args(object):
"-z",
type=int,
default=6,
choices=range(0, 9),
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
)
@ -943,7 +943,6 @@ class Args(object):
"--png_compression",
"-z",
type=int,
default=6,
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",
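
The `choices` change fixes an off-by-one: Python's `range` excludes its stop
value, so `range(0, 9)` rejected the documented maximum compression level of 9.

```py
>>> list(range(0, 9))    # old: level 9 was rejected
[0, 1, 2, 3, 4, 5, 6, 7, 8]
>>> list(range(0, 10))   # new: matches "from 0 (none) to 9 (maximum)"
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```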

@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
@ -497,7 +498,8 @@ class Generator:
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result
def sample_to_lowres_estimated_image(self, samples):
@staticmethod
def sample_to_lowres_estimated_image(samples):
# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

@ -159,6 +159,7 @@ class Inpaint(Img2Img):
seam_size: int,
seam_blur: int,
prompt,
seed,
sampler,
steps,
cfg_scale,
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise)
result = make_image(seam_noise, seed)
return result
@ -342,6 +343,7 @@ class Inpaint(Img2Img):
seam_size,
seam_blur,
prompt,
seed,
sampler,
seam_steps,
cfg_scale,

@ -372,22 +372,32 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(" | Extracting EMA weights (usually better for inference)")
print(" | Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith(unet_key):
if key.startswith("model.diffusion_model") and key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
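
The fallback `flat_ema_key_alt` exists because the two EMA naming schemes
flatten a parameter path differently. A worked example with a hypothetical
UNet key:

```py
# Hypothetical UNet parameter name, flattened the two ways the diff checks for.
key = "model.diffusion_model.input_blocks.0.weight"
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
print(flat_ema_key)      # model_ema.diffusion_modelinput_blocks0weight
print(flat_ema_key_alt)  # model_ema.input_blocks0weight
```

If neither flattened key exists in the checkpoint, the code now falls back to
the non-EMA weight instead of raising a `KeyError` on `checkpoint.pop`.
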
@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
return text_model
def replace_checkpoint_vae(checkpoint, vae_path:str):
if vae_path.endswith(".safetensors"):
vae_ckpt = load_file(vae_path)
else:
vae_ckpt = torch.load(vae_path, map_location="cpu")
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
for vae_key in state_dict:
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
with warnings.catch_warnings():
@ -1074,11 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
checkpoint = (
load_file(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors"
else torch.load(checkpoint_path)
)
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
pipeline_class = (
StableDiffusionGeneratorPipeline
@ -1090,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
print(" | global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@ -1201,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model, or use the one passed
if not vae:
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
print(f" | Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
print(" | Using checkpoint model's original VAE")
if vae:
print(" | Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
@ -1213,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
else:
print(" | Using external VAE specified in config")
# Convert the text model.
model_type = pipeline_type
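
The new `vae_path` plumbing splices a standalone VAE checkpoint into the
model's state dict before conversion by prefixing each VAE key with
`first_stage_model.`. A hypothetical invocation of the helper added above
(both paths are placeholders):

```py
import torch

# Placeholder paths; replace_checkpoint_vae is the helper defined above.
checkpoint = torch.load("model.ckpt", map_location="cpu")
replace_checkpoint_vae(checkpoint, "vae-ft-mse-840000-ema-pruned.ckpt")
# A VAE key such as "encoder.conv_in.weight" now lives at
# "first_stage_model.encoder.conv_in.weight" in the merged checkpoint.
```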

@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -45,9 +45,6 @@ class SDLegacyType(Enum):
UNKNOWN = 99
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
}
class ModelManager(object):
'''
@ -285,13 +282,13 @@ class ModelManager(object):
self.stack.remove(model_name)
if delete_files:
if weights:
print(f"** deleting file {weights}")
print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
print(f"** deleting directory {path}")
print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
print(f"** deleting the cached model directory for {repo_id}")
print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
@ -362,6 +359,7 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
@ -381,9 +379,9 @@ class ModelManager(object):
print(f">> Loading diffusers model from {name_or_path}")
if using_fp16:
print(" | Using faster float16 precision")
print(" | Using faster float16 precision")
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@ -434,10 +432,9 @@ class ModelManager(object):
# square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
@ -457,15 +454,21 @@ class ModelManager(object):
from . import load_pipeline_from_original_stable_diffusion_ckpt
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
try:
if self.list_models()[self.current_model]['status'] == 'active':
self.offload_model(self.current_model)
except Exception as e:
pass
vae_path = None
if vae:
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae=vae,
vae_path=vae_path,
return_generator_pipeline=True,
precision=torch.float16 if self.precision == "float16" else torch.float32,
)
@ -512,18 +515,20 @@ class ModelManager(object):
print(f">> Offloading {model_name} to CPU")
model = self.models[model_name]["model"]
model.offload_all()
self.current_model = None
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
print(f">> Scanning Model: {model_name}")
print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
@ -546,7 +551,7 @@ class ModelManager(object):
print("### Exiting InvokeAI")
sys.exit()
else:
print(">> Model scanned ok")
print(" | Model scanned ok")
def import_diffuser_model(
self,
@ -665,7 +670,7 @@ class ModelManager(object):
print(f">> Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
print(f" | {thing} appears to be a URL")
print(f" | {thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
@ -673,15 +678,15 @@ class ModelManager(object):
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
return
else:
print(f" | {thing} appears to be a checkpoint file on disk")
print(f" | {thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
print(f" | {thing} appears to be a diffusers file on disk")
print(f" | {thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@ -692,13 +697,13 @@ class ModelManager(object):
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
print(f" | {thing} appears to be a diffusers model.")
print(f" | {thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
print(
f" |{thing} appears to be a directory. Will scan for models to import"
f" |{thing} appears to be a directory. Will scan for models to import"
)
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
@ -710,7 +715,7 @@ class ModelManager(object):
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
@ -727,32 +732,33 @@ class ModelManager(object):
return
if model_path.stem in self.config: # already imported
print(" | Already imported. Skipping")
print(" | Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
)
checkpoint = None
if model_path.suffix.endswith((".ckpt",".pt")):
self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided
if model_config_file is None:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
print(" | SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected; model will be converted to diffusers format"
" | SD-v2-v model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
@ -760,7 +766,7 @@ class ModelManager(object):
convert = True
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected; model will be converted to diffusers format"
" | SD-v2-e model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
@ -788,18 +794,21 @@ class ModelManager(object):
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae:dict=None,
vae_path:Path=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@ -827,18 +836,23 @@ class ModelManager(object):
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None
vae_model=None
if vae:
vae_model=self._load_vae(vae)
vae_path=None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
print(
f" | Success. Optimized model is now located at {str(diffusers_path)}"
f" | Success. Optimized model is now located at {str(diffusers_path)}"
)
print(f" | Writing new config file entry for {model_name}")
print(f" | Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@ -849,7 +863,7 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
print(">> Conversion succeeded")
print(" | Conversion succeeded")
except Exception as e:
print(f"** Conversion failed: {str(e)}")
print(
@ -879,36 +893,6 @@ class ModelManager(object):
return search_folder, found_models
def _choose_diffusers_vae(
self, model_name: str, vae: str = None
) -> Union[dict, str]:
# In the event that the original entry is using a custom ckpt VAE, we try to
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
# I would prefer to do this differently: We load the ckpt model into memory, swap the
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
# VAE is built into the model. However, when I tried this I got obscure key errors.
if vae:
return vae
if model_name in self.config and (
vae_ckpt_path := self.model_info(model_name).get("vae", None)
):
vae_basename = Path(vae_ckpt_path).stem
diffusers_vae = None
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
print(
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
)
vae = {"repo_id": diffusers_vae}
else:
print(
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
)
print(
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
)
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
return vae
def _make_cache_room(self) -> None:
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
@ -1105,7 +1089,7 @@ class ModelManager(object):
with open(hashpath) as f:
hash = f.read()
return hash
print(" | Calculating sha256 hash of model files")
print(" | Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@ -1117,7 +1101,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
@ -1162,12 +1146,12 @@ class ModelManager(object):
local_files_only=not Globals.internet_available,
)
print(f" | Loading diffusers VAE from {name_or_path}")
print(f" | Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
fp_args_list = [{}]
vae = None
@ -1208,7 +1192,7 @@ class ModelManager(object):
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()

@ -6,7 +6,6 @@ The interface is through the Concepts() object.
"""
import os
import re
import traceback
from typing import Callable
from urllib import error as ul_error
from urllib import request
@ -15,7 +14,6 @@ from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
"""
if not concept_name in self.list_concepts():
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
except ul_error.HTTPError as e:
if e.code == 404:
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
return False
except ul_error.URLError as e:
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False

@ -3,6 +3,9 @@ import math
import multiprocessing as mp
import os
import re
import io
import base64
from collections import abc
from inspect import isfunction
from pathlib import Path
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64
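
A quick check of the new helper; the solid gray image is just a stand-in, and
the absolute import path is an assumption:

```py
from PIL import Image

from invokeai.backend.util.util import image_to_dataURL  # assumed module path

img = Image.new("RGB", (64, 64), color="gray")  # stand-in preview image
url = image_to_dataURL(img, image_format="JPEG")
print(url[:30])  # data:image/jpeg;base64,...
```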

@ -772,16 +772,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
original_config_file = Path(model_info["config"])
model_name = model_name_or_path
model_description = model_info["description"]
vae = model_info["vae"]
vae_path = model_info.get("vae")
else:
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
Path(vae).stem
):
vae_repo = dict(repo_id=vae_repo)
else:
vae_repo = None
model_name = manager.convert_and_import(
ckpt_path,
diffusers_path=Path(
@ -790,7 +784,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name=model_name,
model_description=model_description,
original_config_file=original_config_file,
vae=vae_repo,
vae_path=vae_path,
)
else:
try:

(Diffs of two generated frontend bundle files suppressed because their lines are too long.)

@ -1,4 +1,4 @@
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-2ad84bef.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-e63a2dc4.js";var Or=`
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
:root {
--chakra-vh: 100vh;
}

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-2ad84bef.js"></script>
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
<link rel="stylesheet" href="./assets/index-5483945c.css">
</head>

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

@ -1,3 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
export {};
declare module 'redux-socket.io-middleware';
@ -40,5 +44,35 @@ declare global {
/* eslint-enable @typescript-eslint/no-explicit-any */
}
declare function Invoke(): React.JSX;
declare module '@invoke-ai/invoke-ai-ui' {
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
public constructor(props: ThemeChangerProps);
}
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
public constructor(props: InvokeAILogoComponentProps);
}
declare class IAIPopover extends React.Component<IAIPopoverProps> {
public constructor(props: IAIPopoverProps);
}
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
public constructor(props: IAIIconButtonProps);
}
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
export {
ThemeChanger,
InvokeAiLogoComponent,
IAIPopover,
IAIIconButton,
SettingsModal,
};
export = Invoke;

@ -6,7 +6,6 @@
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"build:package": "vite build --mode=package",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

@ -8,7 +8,6 @@
"darkTheme": "داكن",
"lightTheme": "فاتح",
"greenTheme": "أخضر",
"text2img": "نص إلى صورة",
"img2img": "صورة إلى صورة",
"unifiedCanvas": "لوحة موحدة",
"nodes": "عقد",

@ -7,7 +7,6 @@
"darkTheme": "Dunkel",
"lightTheme": "Hell",
"greenTheme": "Grün",
"text2img": "Text zu Bild",
"img2img": "Bild zu Bild",
"nodes": "Knoten",
"langGerman": "Deutsch",

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

@ -8,7 +8,6 @@
"darkTheme": "Oscuro",
"lightTheme": "Claro",
"greenTheme": "Verde",
"text2img": "Texto a Imagen",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Nodos",
@ -70,7 +69,11 @@
"langHebrew": "Hebreo",
"pinOptionsPanel": "Pin del panel de opciones",
"loading": "Cargando",
"loadingInvokeAI": "Cargando invocar a la IA"
"loadingInvokeAI": "Cargando invocar a la IA",
"postprocessing": "Tratamiento posterior",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar"
},
"gallery": {
"generations": "Generaciones",
@ -404,7 +407,8 @@
"none": "ninguno",
"pickModelType": "Elige el tipo de modelo",
"v2_768": "v2 (768px)",
"addDifference": "Añadir una diferencia"
"addDifference": "Añadir una diferencia",
"scanForModels": "Buscar modelos"
},
"parameters": {
"images": "Imágenes",
@ -574,7 +578,7 @@
"autoSaveToGallery": "Guardar automáticamente en galería",
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
"limitStrokesToBox": "Limitar trazos a la caja",
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
"clearCanvasHistory": "Limpiar historial de lienzo",
"clearHistory": "Limpiar historial",
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",

View File

@ -8,7 +8,6 @@
"darkTheme": "Sombre",
"lightTheme": "Clair",
"greenTheme": "Vert",
"text2img": "Texte en image",
"img2img": "Image en image",
"unifiedCanvas": "Canvas unifié",
"nodes": "Nœuds",
@ -47,7 +46,19 @@
"statusLoadingModel": "Chargement du modèle",
"statusModelChanged": "Modèle changé",
"discordLabel": "Discord",
"githubLabel": "Github"
"githubLabel": "Github",
"accept": "Accepter",
"statusMergingModels": "Mélange des modèles",
"loadingInvokeAI": "Chargement de Invoke AI",
"cancel": "Annuler",
"langEnglish": "Anglais",
"statusConvertingModel": "Conversion du modèle",
"statusModelConverted": "Modèle converti",
"loading": "Chargement",
"pinOptionsPanel": "Épingler la page d'options",
"statusMergedModels": "Modèles mélangés",
"txt2img": "Texte vers image",
"postprocessing": "Post-Traitement"
},
"gallery": {
"generations": "Générations",
@ -518,5 +529,15 @@
"betaDarkenOutside": "Assombrir à l'extérieur",
"betaLimitToBox": "Limiter à la boîte",
"betaPreserveMasked": "Conserver masqué"
},
"accessibility": {
"uploadImage": "Charger une image",
"reset": "Réinitialiser",
"nextImage": "Image suivante",
"previousImage": "Image précédente",
"useThisParameter": "Utiliser ce paramètre",
"zoomIn": "Zoom avant",
"zoomOut": "Zoom arrière",
"showOptionsPanel": "Montrer la page d'options"
}
}

View File

@ -125,7 +125,6 @@
"langSimplifiedChinese": "סינית",
"langUkranian": "אוקראינית",
"langSpanish": "ספרדית",
"text2img": "טקסט לתמונה",
"img2img": "תמונה לתמונה",
"unifiedCanvas": "קנבס מאוחד",
"nodes": "צמתים",

View File

@ -8,7 +8,6 @@
"darkTheme": "Scuro",
"lightTheme": "Chiaro",
"greenTheme": "Verde",
"text2img": "Testo a Immagine",
"img2img": "Immagine a Immagine",
"unifiedCanvas": "Tela unificata",
"nodes": "Nodi",
@ -70,7 +69,11 @@
"loading": "Caricamento in corso",
"oceanTheme": "Oceano",
"langHebrew": "Ebraico",
"loadingInvokeAI": "Caricamento Invoke AI"
"loadingInvokeAI": "Caricamento Invoke AI",
"postprocessing": "Post Elaborazione",
"txt2img": "Testo a Immagine",
"accept": "Accetta",
"cancel": "Annulla"
},
"gallery": {
"generations": "Generazioni",
@ -404,7 +407,8 @@
"v2_768": "v2 (768px)",
"none": "niente",
"addDifference": "Aggiungi differenza",
"pickModelType": "Scegli il tipo di modello"
"pickModelType": "Scegli il tipo di modello",
"scanForModels": "Cerca modelli"
},
"parameters": {
"images": "Immagini",
@ -574,7 +578,7 @@
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
"saveBoxRegionOnly": "Salva solo l'area di selezione",
"limitStrokesToBox": "Limita i tratti all'area di selezione",
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
"clearCanvasHistory": "Cancella cronologia Tela",
"clearHistory": "Cancella la cronologia",
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
@ -612,7 +616,7 @@
"copyMetadataJson": "Copia i metadati JSON",
"exitViewer": "Esci dal visualizzatore",
"zoomIn": "Zoom avanti",
"zoomOut": "Zoom Indietro",
"zoomOut": "Zoom indietro",
"rotateCounterClockwise": "Ruotare in senso antiorario",
"rotateClockwise": "Ruotare in senso orario",
"flipHorizontally": "Capovolgi orizzontalmente",

View File

@ -11,7 +11,6 @@
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"text2img": "텍스트->이미지",
"unifiedCanvas": "통합 캔버스",
"langFrench": "Français",
"langGerman": "Deutsch",

View File

@ -8,7 +8,6 @@
"darkTheme": "Donker",
"lightTheme": "Licht",
"greenTheme": "Groen",
"text2img": "Tekst naar afbeelding",
"img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas",
"nodes": "Knooppunten",

View File

@ -8,7 +8,6 @@
"darkTheme": "Ciemny",
"lightTheme": "Jasny",
"greenTheme": "Zielony",
"text2img": "Tekst na obraz",
"img2img": "Obraz na obraz",
"unifiedCanvas": "Tryb uniwersalny",
"nodes": "Węzły",

View File

@ -20,7 +20,6 @@
"langSpanish": "Espanhol",
"langRussian": "Русский",
"langUkranian": "Украї́нська",
"text2img": "Texto para Imagem",
"img2img": "Imagem para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nós",

View File

@ -8,7 +8,6 @@
"darkTheme": "Noite",
"lightTheme": "Dia",
"greenTheme": "Verde",
"text2img": "Texto Para Imagem",
"img2img": "Imagem Para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nódulos",

View File

@ -8,7 +8,6 @@
"darkTheme": "Темная",
"lightTheme": "Светлая",
"greenTheme": "Зеленая",
"text2img": "Изображение из текста (text2img)",
"img2img": "Изображение в изображение (img2img)",
"unifiedCanvas": "Универсальный холст",
"nodes": "Ноды",

View File

@ -8,7 +8,6 @@
"darkTheme": "Темна",
"lightTheme": "Світла",
"greenTheme": "Зелена",
"text2img": "Зображення із тексту (text2img)",
"img2img": "Зображення із зображення (img2img)",
"unifiedCanvas": "Універсальне полотно",
"nodes": "Вузли",

View File

@ -8,7 +8,6 @@
"darkTheme": "暗色",
"lightTheme": "亮色",
"greenTheme": "绿色",
"text2img": "文字到图像",
"img2img": "图像到图像",
"unifiedCanvas": "统一画布",
"nodes": "节点",

View File

@ -33,7 +33,6 @@
"langBrPortuguese": "巴西葡萄牙語",
"langRussian": "俄語",
"langSpanish": "西班牙語",
"text2img": "文字到圖像",
"unifiedCanvas": "統一畫布"
}
}

View File

@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks';
import { useEffect } from 'react';
import { PropsWithChildren, useEffect } from 'react';
keepGUIAlive();
const App = () => {
const App = (props: PropsWithChildren) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
@ -40,7 +40,7 @@ const App = () => {
w={APP_WIDTH}
h={APP_HEIGHT}
>
<SiteHeader />
{props.children || <SiteHeader />}
<Flex gap={4} w="full" h="full">
<InvokeTabs />
<ImageGalleryPanel />

View File

@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
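
For reference, a minimal sketch of what this generator evaluates to; the commented variant reproducing the removed 64–2048 list is illustrative only:

const sizes: Array<number> = Array.from(Array(65)).map((_x, i) => i * 64);
// -> [0, 64, 128, ..., 4096] (65 values, starting at 0 and ending at 4096)

// To reproduce the original 64–2048 range instead:
// Array.from({ length: 32 }, (_, i) => (i + 1) * 64); // -> [64, 128, ..., 2048]
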
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [

View File

@ -9,6 +9,7 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
type Props = {
@ -22,10 +23,12 @@ type Props = {
};
const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { t } = useTranslation();
const {
acceptButtonText = 'Accept',
acceptButtonText = t('common.accept'),
acceptCallback,
cancelButtonText = 'Cancel',
cancelButtonText = t('common.cancel'),
cancelCallback,
children,
title,
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>

View File

@ -0,0 +1,8 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;
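
A minimal usage sketch, assuming Chakra UI's `chakra` factory semantics: the wrapped <form /> accepts Chakra style props such as `sx` alongside normal form props, as the model-manager forms below demonstrate:

<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
  {/* form controls */}
</IAIForm>
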

View File

@ -0,0 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { ReactElement } from 'react';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
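
A minimal usage sketch, grouping a FormControl the way the model-manager forms below do:

<IAIFormItemWrapper>
  <FormControl isRequired>{/* label, input, helper text */}</FormControl>
</IAIFormItemWrapper>
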

View File

@ -8,7 +8,7 @@ import {
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIPopoverProps = PopoverProps & {
export type IAIPopoverProps = PopoverProps & {
triggerComponent: ReactNode;
triggerContainerProps?: BoxProps;
children: ReactNode;

View File

@ -1,4 +1,4 @@
import React, { lazy } from 'react';
import React, { lazy, PropsWithChildren } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store';
@ -21,14 +21,14 @@ import './i18n';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component() {
export default function Component(props: PropsWithChildren) {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App />
<App>{props.children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@ -0,0 +1,16 @@
import Component from './component';
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
export default Component;
export {
InvokeAiLogoComponent,
ThemeChanger,
IAIPopover,
IAIIconButton,
SettingsModal,
};
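
A minimal consumer sketch; the '@invoke-ai/invoke-ai-ui' specifier comes from the typings above, `MyCustomHeader` is a hypothetical host component, and the children pass-through relies on App.tsx rendering `props.children` in place of the default <SiteHeader />:

import Invoke, { ThemeChanger, IAIIconButton } from '@invoke-ai/invoke-ai-ui';

const HostApp = () => (
  <Invoke>
    {/* MyCustomHeader is hypothetical; omit children to keep the stock SiteHeader */}
    <MyCustomHeader />
  </Invoke>
);
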

View File

@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<ButtonGroup>
<IAIIconButton

View File

@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

View File

@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
onClick={handleSelectColorPickerTool}
/>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('unifiedCanvas.brushOptions')}

View File

@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
>
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={`${t('parameters.sendTo')}...`}
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaGrinStars />}
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</IAIPopover>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaExpandArrowsAlt />}

View File

@ -0,0 +1,24 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default CurrentImageFallback;

View File

@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
src={imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
sx={{
objectFit: 'contain',
maxWidth: '100%',

View File

@ -55,7 +55,6 @@ export default function LanguagePicker() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.languagePickerLabel')}

View File

@ -1,4 +1,5 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/invokeai';
import type { RootState } from 'app/store';
import IAIIconButton from 'common/components/IAIIconButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import { BiArrowBack } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
return (
<VStack gap={2} alignItems="flex-start">
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Flex columnGap={4}>
<IAICheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
<SearchModels />
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
{addManually && (
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="2xl"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="2xl"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="2xl"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="2xl"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="2xl"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
width="90%"
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
width="90%"
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);

View File

@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import { BiArrowBack } from 'react-icons/bi';
import type { RootState } from 'app/store';
import type { ReactElement } from 'react';
function FormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
return (
<Flex>
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -14,7 +14,7 @@ import {
import IAIButton from 'common/components/IAIButton';
import { FaPlus } from 'react-icons/fa';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next';
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
@ -83,8 +84,22 @@ export default function AddModel() {
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (

View File

@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
import { isEqual, pickBy } from 'lodash';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -12,14 +12,13 @@ import {
RadioGroup,
Spacer,
Text,
VStack,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next';
import { FaPlus, FaSearch } from 'react-icons/fa';
import { FaSearch, FaTrash } from 'react-icons/fa';
import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system;
@ -71,34 +70,32 @@ function SearchModelEntry({
};
return (
<VStack>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4}>
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4} alignItems="center" justifyContent="space-between">
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
</VStack>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
</Flex>
);
}
@ -215,10 +212,10 @@ export default function SearchModels() {
}
return (
<>
<Flex flexDirection="column" rowGap={4}>
{newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels}
</>
</Flex>
);
};
@ -245,26 +242,26 @@ export default function SearchModels() {
<Text
sx={{
fontWeight: 500,
fontSize: 'sm',
}}
variant="subtext"
>
{t('modelManager.checkpointFolder')}
</Text>
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex>
<Spacer />
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<BiReset />}
icon={<FaSearch />}
fontSize={18}
disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))}
/>
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler}
/>
</Flex>
@ -276,9 +273,9 @@ export default function SearchModels() {
}}
>
{({ handleSubmit }) => (
<form onSubmit={handleSubmit}>
<HStack columnGap={2} alignItems="flex-end" width="100%">
<FormControl isRequired width="lg">
<IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end">
<FormControl flexGrow={1}>
<Field
as={IAIInput}
id="checkpointFolder"
@ -294,12 +291,12 @@ export default function SearchModels() {
tooltip={t('modelManager.findModels')}
type="submit"
disabled={isProcessing}
paddingX={10}
px={8}
>
{t('modelManager.findModels')}
</IAIButton>
</HStack>
</form>
</IAIForm>
)}
</Formik>
)}
@ -410,7 +407,6 @@ export default function SearchModels() {
maxHeight={72}
overflowY="scroll"
borderRadius="sm"
paddingInlineEnd={4}
gap={2}
>
{foundModels.length > 0 ? (

View File

@ -50,7 +50,6 @@ export default function ThemeChanger() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.themeLabel')}

View File

@ -166,20 +166,8 @@ export default function InvokeTabs() {
[]
);
/**
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
*
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
* the watcher for the hidden drawers, closing the open drawer.
*
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
* tab's drawer.
*/
return (
<Tabs
isLazy
defaultIndex={activeTab}
index={activeTab}
onChange={(index: number) => {

View File

@ -93,12 +93,9 @@ const ResizableDrawer = ({
useOutsideClick({
ref: outsideClickRef,
handler: () => {
if (isPinned) {
return;
}
onClose();
},
enabled: isOpen && !isPinned,
});
const handleEnables = useMemo(

View File

@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<Box
sx={{

View File

@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {
return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,3 @@
import path from 'path';
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { defineConfig, PluginOption } from 'vite';
@ -58,26 +57,6 @@ export default defineConfig(({ mode }) => {
// sourcemap: true, // this can be enabled if needed, it adds over 15MB to the commit
},
};
} else if (mode === 'package') {
return {
...common,
build: {
...common.build,
lib: {
entry: path.resolve(__dirname, 'src/component.tsx'),
name: 'InvokeAI UI',
fileName: (format) => `invoke-ai-ui.${format}.js`,
},
rollupOptions: {
external: ['react', 'react-dom'],
output: {
globals: {
react: 'React',
},
},
},
},
};
} else {
return {
...common,

View File

@ -38,16 +38,16 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==0.1.10",
"compel==1.0.4",
"datasets",
"diffusers[torch]~=0.14",
"dnspython==2.2.1",
"einops",
"eventlet",
"facexlib",
"fastapi==0.85.0",
"fastapi-events==0.6.0",
"fastapi-socketio==0.0.9",
"fastapi==0.94.1",
"fastapi-events==0.8.0",
"fastapi-socketio==0.0.10",
"flask==2.1.3",
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
@ -75,7 +75,7 @@ dependencies = [
"torchvision>=0.14.1",
"torchmetrics",
"transformers~=4.26",
"uvicorn[standard]==0.20.0",
"uvicorn[standard]==0.21.1",
"windows-curses; sys_platform=='win32'",
]
@ -139,8 +139,24 @@ version = { attr = "invokeai.version.__version__" }
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
"invokeai.frontend.web.dist" = ["**"]
#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "-p pytest_cov --junitxml=junit/test-results.xml --cov-report=term:skip-covered --cov=ldm/invoke --cov=backend --cov-branch"
addopts = "--cov-report term --cov-report html --cov-report xml"
[tool.coverage.run]
branch = true
source = ["invokeai"]
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
[tool.coverage.report]
show_missing = true
fail_under = 85 # let's set something sensible on Day 1 ...
[tool.coverage.json]
output = "coverage/coverage.json"
pretty_print = true
[tool.coverage.html]
directory = "coverage/html"
[tool.coverage.xml]
output = "coverage/index.xml"
#=== End: PyTest and Coverage
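
A usage sketch, assuming pytest-cov is installed: because [tool.coverage.run] already pins source = ["invokeai"], a bare --cov flag is enough to enable collection, and the addopts above then emit the terminal, coverage/html, and coverage/index.xml reports:

pytest --cov
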
[flake8]
max-line-length = 120

View File

@ -105,17 +105,20 @@
// Start building nodes
var id = 1;
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
id++;
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
id++;
var upscaleNode = {"id": id.toString(), "type": "show_image" };
id++;
nodes = {};
nodes[initialNode.id] = initialNode;
nodes[i2iNode.id] = i2iNode;
nodes[upscaleNode.id] = upscaleNode;
links = [
[{ "node_id": initialNode.id, field: "image" },
{ "node_id": upscaleNode.id, field: "image" }]
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
];
// expandSize = 128;
// for (var i = 0; i < 6; ++i) {

View File

@ -1,15 +1,18 @@
from invokeai.app.invocations.image import *
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
import pytest
# Helpers
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field)
)
# Tests
def test_connections_are_compatible():
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
assert g.get_node("3").prompt == "Banana sushi"
assert len(g.edges) == 1
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
def test_graph_fails_to_update_node_id_if_conflict():
g = Graph()
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
assert g2.nodes['1'] is not None
assert g2.nodes['2'] is not None
assert len(g2.edges) == 1
assert g2.edges[0][0].node_id == '1'
assert g2.edges[0][0].field == 'image'
assert g2.edges[0][1].node_id == '2'
assert g2.edges[0][1].field == 'image'
assert g2.edges[0].source.node_id == '1'
assert g2.edges[0].source.field == 'image'
assert g2.edges[0].destination.node_id == '2'
assert g2.edges[0].destination.field == 'image'
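
The assertions above pin down the Edge shape; a minimal sketch of such a model, assuming a pydantic BaseModel like the rest of invokeai.app.services.graph (the real definition may carry extra validation):

from pydantic import BaseModel

class Edge(BaseModel):
    source: EdgeConnection       # upstream node/field, e.g. node '1', field 'image'
    destination: EdgeConnection  # downstream node/field, e.g. node '2', field 'image'
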
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient

View File

@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import EdgeConnection
from invokeai.app.services.graph import Edge, EdgeConnection
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
return Edge(
source=EdgeConnection(node_id = from_id, field = from_field),
destination=EdgeConnection(node_id = to_id, field = to_field))
class TestEvent: