mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bugfix/dreambooth_ema
This commit is contained in:
commit
6e7dbf99f3
@ -1,6 +0,0 @@
|
||||
[run]
|
||||
omit='.env/*'
|
||||
source='.'
|
||||
|
||||
[report]
|
||||
show_missing = true
|
10
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
10
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
@ -65,6 +65,16 @@ body:
|
||||
placeholder: 8GB
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: version-number
|
||||
attributes:
|
||||
label: What version did you experience this issue on?
|
||||
description: |
|
||||
Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: X.X.X
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -63,6 +63,7 @@ pip-delete-this-directory.txt
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coveragerc
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
@ -73,6 +74,7 @@ cov.xml
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
.pytest.ini
|
||||
cover/
|
||||
junit/
|
||||
|
||||
|
@ -1,5 +0,0 @@
|
||||
[pytest]
|
||||
DJANGO_SETTINGS_MODULE = webtas.settings
|
||||
; python_files = tests.py test_*.py *_tests.py
|
||||
|
||||
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
|
4
coverage/.gitignore
vendored
Normal file
4
coverage/.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
BIN
docs/assets/contributing/html-detail.png
Normal file
BIN
docs/assets/contributing/html-detail.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 470 KiB |
BIN
docs/assets/contributing/html-overview.png
Normal file
BIN
docs/assets/contributing/html-overview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 457 KiB |
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
@ -0,0 +1,83 @@
|
||||
# Local Development
|
||||
|
||||
If you are looking to contribute you will need to have a local development
|
||||
environment. See the
|
||||
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
|
||||
full details.
|
||||
|
||||
Broadly this involves cloning the repository, installing the pre-reqs, and
|
||||
InvokeAI (in editable form). Assuming this is working, choose your area of
|
||||
focus.
|
||||
|
||||
## Documentation
|
||||
|
||||
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
|
||||
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
|
||||
written in markdown files under the `./docs` folder and then built into a static
|
||||
website for hosting with GitHub Pages at
|
||||
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
|
||||
|
||||
To contribute to the documentation you'll need to install the dependencies. Note
|
||||
the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[docs]"
|
||||
```
|
||||
|
||||
Now, to run the documentation locally with hot-reloading for changes made.
|
||||
|
||||
```zsh
|
||||
mkdocs serve
|
||||
```
|
||||
|
||||
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
|
||||
access.
|
||||
|
||||
## Backend
|
||||
|
||||
The backend is contained within the `./invokeai/backend` folder structure. To
|
||||
get started however please install the development dependencies.
|
||||
|
||||
From the root of the repository run the following command. Note the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[test]"
|
||||
```
|
||||
|
||||
This in an optional group of packages which is defined within the
|
||||
`pyproject.toml` and will be required for testing the changes you make the the
|
||||
code.
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
|
||||
be found under the `./tests` folder and can be run with a single `pytest`
|
||||
command. Optionally, to review test coverage you can append `--cov`.
|
||||
|
||||
```zsh
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition a more
|
||||
detailed report is created in both XML and HTML format in the `./coverage`
|
||||
folder. The HTML one in particular can help identify missing statements
|
||||
requiring tests to ensure coverage. This can be run by opening
|
||||
`./coverage/html/index.html`.
|
||||
|
||||
For example.
|
||||
|
||||
```zsh
|
||||
pytest --cov; open ./coverage/html/index.html
|
||||
```
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Front End
|
||||
|
||||
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
||||
|
||||
--8<-- "invokeai/frontend/web/README.md"
|
@ -17,7 +17,7 @@ notebooks.
|
||||
|
||||
You will need a GPU to perform training in a reasonable length of
|
||||
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
|
||||
library](../installation/070_INSTALL_XFORMERS) to accelerate the
|
||||
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
|
||||
training process further. During training, about ~8 GB is temporarily
|
||||
needed in order to store intermediate models, checkpoints and logs.
|
||||
|
||||
|
@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
|
||||
brew install opencv
|
||||
```
|
||||
|
||||
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
|
||||
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
|
||||
|
||||
## Linux
|
||||
|
||||
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
|
||||
`python`, and then at the `>>>` line type
|
||||
`from patchmatch import patch_match`: It should look like the follwing:
|
||||
`from patchmatch import patch_match`: It should look like the following:
|
||||
|
||||
```py
|
||||
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
|
||||
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
|
||||
|
||||
If you see no errors, then you're ready to go!
|
||||
If you see no errors you're ready to go!
|
||||
|
@ -10,6 +10,7 @@ from pydantic.fields import Field
|
||||
from ...invocations import *
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
@ -92,7 +93,7 @@ async def get_session(
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
@ -125,7 +126,7 @@ async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
@ -186,7 +187,7 @@ async def delete_node(
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
|
||||
edge: Edge = Body(description="The edge to add"),
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -228,9 +229,9 @@ async def delete_edge(
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
edge = (
|
||||
EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
|
@ -112,10 +112,8 @@ def custom_openapi():
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
if "additionalProperties" not in invoker_schema:
|
||||
invoker_schema["additionalProperties"] = {}
|
||||
|
||||
invoker_schema["additionalProperties"]["outputs"] = outputs_ref
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
@ -94,9 +94,9 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
edges = [
|
||||
(
|
||||
EdgeConnection(node_id=a.id, field=field),
|
||||
EdgeConnection(node_id=b.id, field=field),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
)
|
||||
for field in matching_fields
|
||||
]
|
||||
@ -111,16 +111,15 @@ class SessionError(Exception):
|
||||
def invoke_all(context: CliContext):
|
||||
"""Runs all invocations in the specified session"""
|
||||
context.invoker.invoke(context.session, invoke_all=True)
|
||||
while not context.session.is_complete():
|
||||
while not context.get_session().is_complete():
|
||||
# Wait some time
|
||||
session = context.get_session()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
@ -203,7 +202,7 @@ def invoke_cli():
|
||||
continue
|
||||
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges = []
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = (
|
||||
history[0] if current_id == start_id else str(current_id - 1)
|
||||
@ -225,19 +224,19 @@ def invoke_cli():
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
)
|
||||
matching_destinations = [e[1] for e in matching_edges]
|
||||
edges = [e for e in edges if e[1] not in matching_destinations]
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||
edges.append(
|
||||
(
|
||||
EdgeConnection(node_id=link[1], field=link[0]),
|
||||
EdgeConnection(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -4,6 +4,8 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, sample: Any = None, step: int = 0
|
||||
) -> None:
|
||||
self, context: InvocationContext, sample: Tensor, step: int
|
||||
) -> None:
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
self.id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
float(step) / float(self.steps),
|
||||
self.steps,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
def step_callback(sample, step=0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
@ -118,7 +136,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
generator_output = next(
|
||||
Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
@ -179,8 +197,8 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
generator_output = next(
|
||||
Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
|
@ -1,7 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, TypedDict
|
||||
|
||||
ProgressImage = TypedDict(
|
||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||
)
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -23,8 +26,9 @@ class EventServiceBase:
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
percent: float,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
@ -32,8 +36,9 @@ class EventServiceBase:
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
percent=percent,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
|
||||
return hash(f"{self.node_id}.{self.field}")
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
|
||||
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
|
||||
|
||||
|
||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_outputs = get_type_hints(node_type.get_output_type())
|
||||
@ -194,7 +199,7 @@ class Graph(BaseModel):
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
)
|
||||
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
@ -251,7 +256,7 @@ class Graph(BaseModel):
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
"""Adds an edge to a graph
|
||||
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
@ -262,7 +267,7 @@ class Graph(BaseModel):
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
"""Deletes an edge from a graph"""
|
||||
|
||||
try:
|
||||
@ -280,7 +285,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set(
|
||||
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
|
||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||
)
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
@ -294,10 +299,10 @@ class Graph(BaseModel):
|
||||
if not all(
|
||||
(
|
||||
are_connections_compatible(
|
||||
self.get_node(e[0].node_id),
|
||||
e[0].field,
|
||||
self.get_node(e[1].node_id),
|
||||
e[1].field,
|
||||
self.get_node(e.source.node_id),
|
||||
e.source.field,
|
||||
self.get_node(e.destination.node_id),
|
||||
e.destination.field,
|
||||
)
|
||||
for e in self.edges
|
||||
)
|
||||
@ -328,58 +333,58 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge[0].node_id)
|
||||
to_node = self.get_node(edge[1].node_id)
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge[0].node_id, edge[1].node_id)
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge[0].field, to_node, edge[1].field
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
@ -438,15 +443,15 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[1].node_id
|
||||
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
edge[0],
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[1].field
|
||||
),
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.destination.field
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@ -454,51 +459,51 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[0].node_id
|
||||
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[0].field
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.source.field
|
||||
),
|
||||
edge[1],
|
||||
destination=edge.destination
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(
|
||||
self, node_path: str, field: Optional[str] = None
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -522,37 +527,37 @@ class Graph(BaseModel):
|
||||
|
||||
def _get_output_edges(
|
||||
self, node_path: str, field: str
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -580,8 +585,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -622,8 +627,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -684,7 +689,7 @@ class Graph(BaseModel):
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
@ -711,7 +716,7 @@ class Graph(BaseModel):
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
g.add_edges_from(
|
||||
[
|
||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'id',
|
||||
'graph',
|
||||
'execution_graph',
|
||||
'executed',
|
||||
'executed_history',
|
||||
'results',
|
||||
'errors',
|
||||
'prepared_source_mapping',
|
||||
'source_prepared_mapping',
|
||||
]
|
||||
}
|
||||
|
||||
def next(self) -> BaseInvocation | None:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1]
|
||||
for n in iteration_node_map
|
||||
if n[0] == input_collection_edge[0].node_id
|
||||
if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[
|
||||
input_collection_prepared_node_id
|
||||
]
|
||||
input_collection = getattr(
|
||||
input_collection_prepared_node_output, input_collection_edge[0].field
|
||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||
)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (
|
||||
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
|
||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||
):
|
||||
new_edge = (
|
||||
EdgeConnection(node_id=input_node_id, field=edge[0].field),
|
||||
EdgeConnection(node_id="", field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||
)
|
||||
new_edges.append(new_edge)
|
||||
|
||||
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
new_edge = (
|
||||
edge[0],
|
||||
EdgeConnection(node_id=new_node.id, field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
)
|
||||
self.execution_graph.add_edge(new_edge)
|
||||
|
||||
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
|
||||
return self.execution_graph.nodes[next_node]
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
|
||||
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||
if isinstance(node, CollectInvocation):
|
||||
output_collection = [
|
||||
getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
for edge in input_edges
|
||||
if edge[1].field == "item"
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
setattr(node, "collection", output_collection)
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
setattr(node, edge[1].field, output_value)
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
setattr(node, edge.destination.field, output_value)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
if edge[1].node_id in self.source_prepared_mapping:
|
||||
if edge.destination.node_id in self.source_prepared_mapping:
|
||||
return False
|
||||
|
||||
# Otherwise, the edge is valid
|
||||
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
|
||||
)
|
||||
self.graph.add_edge(edge)
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
)
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
@ -490,7 +490,7 @@ class Args(object):
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0, 9),
|
||||
choices=range(0, 10),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
||||
)
|
||||
@ -943,7 +943,6 @@ class Args(object):
|
||||
"--png_compression",
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0, 10),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",
|
||||
|
@ -58,7 +58,7 @@ class InvokeAIGeneratorOutput:
|
||||
'''
|
||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||
operation, including the image, its seed, the model name used to generate the image
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
generating the image (in .params, also available as attributes)
|
||||
'''
|
||||
image: Image
|
||||
@ -116,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
|
||||
for o in outputs:
|
||||
print(o.image, o.seed)
|
||||
|
||||
|
||||
'''
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
@ -167,7 +168,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
|
||||
@classmethod
|
||||
def schedulers(self)->List[str]:
|
||||
'''
|
||||
@ -177,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
@ -267,12 +268,12 @@ class Embiggen(Txt2Img):
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .embiggen import Embiggen
|
||||
return Embiggen
|
||||
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
@ -347,7 +348,6 @@ class Generator:
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
attention_maps_callback=attention_maps_callback,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
results = []
|
||||
@ -375,7 +375,8 @@ class Generator:
|
||||
print("** An error occurred while getting initial noise **")
|
||||
print(traceback.format_exc())
|
||||
|
||||
image = make_image(x_T)
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
image = make_image(x_T, seed)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_checker.check(image)
|
||||
@ -497,7 +498,8 @@ class Generator:
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
def sample_to_lowres_estimated_image(self, samples):
|
||||
@staticmethod
|
||||
def sample_to_lowres_estimated_image(samples):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
|
@ -37,7 +37,6 @@ class Img2Img(Generator):
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
seed=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -64,7 +63,7 @@ class Img2Img(Generator):
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T):
|
||||
def make_image(x_T: torch.Tensor, seed: int):
|
||||
# FIXME: use x_T for initial seeded noise
|
||||
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||
# necessary, which the x_T input might not match.
|
||||
@ -77,7 +76,7 @@ class Img2Img(Generator):
|
||||
conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback,
|
||||
seed=seed
|
||||
seed=seed,
|
||||
)
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
@ -88,9 +87,7 @@ class Img2Img(Generator):
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu").to(device)
|
||||
|
@ -159,6 +159,7 @@ class Inpaint(Img2Img):
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise)
|
||||
result = make_image(seam_noise, seed)
|
||||
|
||||
return result
|
||||
|
||||
@ -223,7 +224,6 @@ class Inpaint(Img2Img):
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
seed=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -311,7 +311,7 @@ class Inpaint(Img2Img):
|
||||
uc, c, cfg_scale
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T):
|
||||
def make_image(x_T: torch.Tensor, seed: int):
|
||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||
init_image=init_image,
|
||||
mask=1 - mask, # expects white means "paint here."
|
||||
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
|
||||
conditioning_data=conditioning_data,
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback,
|
||||
seed=seed
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
if (
|
||||
@ -343,6 +343,7 @@ class Inpaint(Img2Img):
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
|
@ -61,7 +61,7 @@ class Txt2Img(Generator):
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
|
@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T):
|
||||
def make_image(x_T: torch.Tensor, _: int):
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
|
@ -1085,9 +1085,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
checkpoint = (
|
||||
load_file(checkpoint_path)
|
||||
if Path(checkpoint_path).suffix == ".safetensors"
|
||||
else torch.load(checkpoint_path)
|
||||
torch.load(checkpoint_path)
|
||||
if Path(checkpoint_path).suffix == ".ckpt"
|
||||
else load_file(checkpoint_path)
|
||||
|
||||
)
|
||||
cache_dir = global_cache_dir("hub")
|
||||
pipeline_class = (
|
||||
|
@ -97,7 +97,7 @@ class ModelManager(object):
|
||||
If on disk, will load from there.
|
||||
"""
|
||||
if not model_name:
|
||||
return self.current_model if self.current_model else self.get_model(self.default_model())
|
||||
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
@ -362,6 +362,7 @@ class ModelManager(object):
|
||||
raise NotImplementedError(
|
||||
f"Unknown model format {model_name}: {model_format}"
|
||||
)
|
||||
self._add_embeddings_to_model(model)
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
@ -436,7 +437,6 @@ class ModelManager(object):
|
||||
height = width
|
||||
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
self._add_embeddings_to_model(pipeline)
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
@ -732,9 +732,9 @@ class ModelManager(object):
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
checkpoint = (
|
||||
safetensors.torch.load_file(model_path)
|
||||
if model_path.suffix == ".safetensors"
|
||||
else torch.load(model_path)
|
||||
torch.load(model_path)
|
||||
if model_path.suffix == ".ckpt"
|
||||
else safetensors.torch.load_file(model_path)
|
||||
)
|
||||
|
||||
# additional probing needed if no config file provided
|
||||
|
@ -6,7 +6,6 @@ The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
@ -15,7 +14,6 @@ from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
ModelSearchArguments,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
print(
|
||||
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
||||
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
print(
|
||||
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
print(
|
||||
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import einops
|
||||
import PIL.Image
|
||||
from accelerate.utils import set_seed
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
device=self._model_group.device_for(self.unet),
|
||||
dtype=self.unet.dtype,
|
||||
)
|
||||
noise = noise_func(initial_latents, seed)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
noise = noise_func(initial_latents)
|
||||
|
||||
return self.img2img_from_latents_and_embeddings(
|
||||
initial_latents,
|
||||
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
init_image_latents = self.non_noised_latents_from_image(
|
||||
init_image, device=device, dtype=latents_dtype
|
||||
)
|
||||
noise = noise_func(init_image_latents, seed)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
noise = noise_func(init_image_latents)
|
||||
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
@ -3,6 +3,9 @@ import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
import io
|
||||
import base64
|
||||
|
||||
from collections import abc
|
||||
from inspect import isfunction
|
||||
from pathlib import Path
|
||||
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
|
||||
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
||||
result = download_with_resume(url, dest, access_token=None)
|
||||
return result is not None
|
||||
|
||||
|
||||
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
"""
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format=image_format)
|
||||
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
|
||||
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
return image_base64
|
||||
|
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
Normal file
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-b40e839f.js
vendored
188
invokeai/frontend/web/dist/assets/App-b40e839f.js
vendored
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-e1f916bd.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-548a355c.js";var Or=`
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-e1f916bd.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
3
invokeai/frontend/web/dist/locales/en.json
vendored
3
invokeai/frontend/web/dist/locales/en.json
vendored
@ -64,6 +64,8 @@
|
||||
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"cancel": "Cancel",
|
||||
"accept": "Accept",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
@ -333,6 +335,7 @@
|
||||
"addNewModel": "Add New Model",
|
||||
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||
"addDiffuserModel": "Add Diffusers",
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"name": "Name",
|
||||
|
37
invokeai/frontend/web/index.d.ts
vendored
37
invokeai/frontend/web/index.d.ts
vendored
@ -1,3 +1,7 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
|
||||
export {};
|
||||
|
||||
declare module 'redux-socket.io-middleware';
|
||||
@ -39,3 +43,36 @@ declare global {
|
||||
}
|
||||
/* eslint-enable @typescript-eslint/no-explicit-any */
|
||||
}
|
||||
|
||||
declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
|
||||
public constructor(props: ThemeChangerProps);
|
||||
}
|
||||
|
||||
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
|
||||
public constructor(props: InvokeAILogoComponentProps);
|
||||
}
|
||||
|
||||
declare class IAIPopover extends React.Component<IAIPopoverProps> {
|
||||
public constructor(props: IAIPopoverProps);
|
||||
}
|
||||
|
||||
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
|
||||
public constructor(props: IAIIconButtonProps);
|
||||
}
|
||||
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
}
|
||||
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
InvokeAiLogoComponent,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
||||
export = Invoke;
|
||||
|
@ -36,6 +36,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/anatomy": "^2.1.1",
|
||||
"@chakra-ui/cli": "^2.3.0",
|
||||
"@chakra-ui/icons": "^2.0.17",
|
||||
"@chakra-ui/react": "^2.5.1",
|
||||
"@chakra-ui/styled-system": "^2.6.1",
|
||||
@ -52,6 +53,7 @@
|
||||
"i18next-http-backend": "^2.1.1",
|
||||
"konva": "^8.4.2",
|
||||
"lodash": "^4.17.21",
|
||||
"patch-package": "^6.5.1",
|
||||
"re-resizable": "^6.9.9",
|
||||
"react": "^18.2.0",
|
||||
"react-colorful": "^5.6.1",
|
||||
@ -72,7 +74,6 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chakra-ui/cli": "^2.3.0",
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@types/dateformat": "^5.0.0",
|
||||
"@types/react": "^18.0.28",
|
||||
@ -92,7 +93,6 @@
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.1.2",
|
||||
"madge": "^6.0.0",
|
||||
"patch-package": "^6.5.1",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.4",
|
||||
"rollup-plugin-visualizer": "^5.9.0",
|
||||
|
@ -64,6 +64,8 @@
|
||||
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"cancel": "Cancel",
|
||||
"accept": "Accept",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
@ -333,6 +335,7 @@
|
||||
"addNewModel": "Add New Model",
|
||||
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||
"addDiffuserModel": "Add Diffusers",
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"name": "Name",
|
||||
|
@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { useEffect } from 'react';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
const App = () => {
|
||||
const App = (props: PropsWithChildren) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
@ -40,7 +40,7 @@ const App = () => {
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
<SiteHeader />
|
||||
{props.children || <SiteHeader />}
|
||||
<Flex gap={4} w="full" h="full">
|
||||
<InvokeTabs />
|
||||
<ImageGalleryPanel />
|
||||
|
@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
|
||||
];
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid image heights
|
||||
export const HEIGHTS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIButton from './IAIButton';
|
||||
|
||||
type Props = {
|
||||
@ -22,10 +23,12 @@ type Props = {
|
||||
};
|
||||
|
||||
const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
acceptButtonText = 'Accept',
|
||||
acceptButtonText = t('common.accept'),
|
||||
acceptCallback,
|
||||
cancelButtonText = 'Cancel',
|
||||
cancelButtonText = t('common.cancel'),
|
||||
cancelCallback,
|
||||
children,
|
||||
title,
|
||||
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
|
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
@ -0,0 +1,8 @@
|
||||
import { chakra } from '@chakra-ui/react';
|
||||
|
||||
/**
|
||||
* Chakra-enabled <form />
|
||||
*/
|
||||
const IAIForm = chakra.form;
|
||||
|
||||
export default IAIForm;
|
@ -0,0 +1,23 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
|
||||
export function IAIFormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -8,13 +8,14 @@ import {
|
||||
import { memo } from 'react';
|
||||
|
||||
export type IAIIconButtonProps = IconButtonProps & {
|
||||
role?: string;
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
isChecked?: boolean;
|
||||
};
|
||||
|
||||
const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
|
||||
const { tooltip = '', tooltipProps, isChecked, ...rest } = props;
|
||||
const { role, tooltip = '', tooltipProps, isChecked, ...rest } = props;
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
@ -27,6 +28,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
|
||||
>
|
||||
<IconButton
|
||||
ref={forwardedRef}
|
||||
role={role}
|
||||
aria-checked={isChecked !== undefined ? isChecked : undefined}
|
||||
{...rest}
|
||||
/>
|
||||
@ -34,4 +36,5 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
|
||||
);
|
||||
});
|
||||
|
||||
IAIIconButton.displayName = 'IAIIconButton';
|
||||
export default memo(IAIIconButton);
|
||||
|
@ -8,7 +8,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
type IAIPopoverProps = PopoverProps & {
|
||||
export type IAIPopoverProps = PopoverProps & {
|
||||
triggerComponent: ReactNode;
|
||||
triggerContainerProps?: BoxProps;
|
||||
children: ReactNode;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { lazy } from 'react';
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
@ -21,14 +21,14 @@ import './i18n';
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
export default function Component() {
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App />
|
||||
<App>{props.children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
16
invokeai/frontend/web/src/exports.tsx
Normal file
16
invokeai/frontend/web/src/exports.tsx
Normal file
@ -0,0 +1,16 @@
|
||||
import Component from './component';
|
||||
|
||||
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
|
||||
import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
InvokeAiLogoComponent,
|
||||
ThemeChanger,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<ButtonGroup>
|
||||
<IAIIconButton
|
||||
|
@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
isLazy={false}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.canvasSettings')}
|
||||
|
@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
|
||||
onClick={handleSelectColorPickerTool}
|
||||
/>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('unifiedCanvas.brushOptions')}
|
||||
|
@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
>
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={`${t('parameters.sendTo')}...`}
|
||||
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaGrinStars />}
|
||||
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</IAIPopover>
|
||||
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaExpandArrowsAlt />}
|
||||
|
@ -0,0 +1,24 @@
|
||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
|
||||
type CurrentImageFallbackProps = SpinnerProps;
|
||||
|
||||
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
||||
const { size = 'xl', ...rest } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<Spinner size={size} {...rest} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageFallback;
|
@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
|
||||
src={imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
|
@ -101,12 +101,14 @@ const ImageGalleryContent = () => {
|
||||
aria-label={t('gallery.showGenerations')}
|
||||
tooltip={t('gallery.showGenerations')}
|
||||
isChecked={currentCategory === 'result'}
|
||||
role="radio"
|
||||
icon={<FaImage />}
|
||||
onClick={() => dispatch(setCurrentCategory('result'))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('gallery.showUploads')}
|
||||
tooltip={t('gallery.showUploads')}
|
||||
role="radio"
|
||||
isChecked={currentCategory === 'user'}
|
||||
icon={<FaUser />}
|
||||
onClick={() => dispatch(setCurrentCategory('user'))}
|
||||
@ -251,4 +253,5 @@ const ImageGalleryContent = () => {
|
||||
);
|
||||
};
|
||||
|
||||
ImageGalleryContent.displayName = 'ImageGalleryContent';
|
||||
export default ImageGalleryContent;
|
||||
|
@ -55,7 +55,6 @@ export default function LanguagePicker() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('common.languagePickerLabel')}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||
import type { RootState } from 'app/store';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
|
||||
|
||||
return (
|
||||
<VStack gap={2} alignItems="flex-start">
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Flex columnGap={4}>
|
||||
<IAICheckbox
|
||||
isChecked={!addManually}
|
||||
label={t('modelManager.scanForModels')}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<SearchModels />
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
|
||||
{addManually && (
|
||||
{addManually ? (
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
|
||||
<VStack rowGap={2}>
|
||||
<Text fontSize={20} fontWeight="bold" alignSelf="start">
|
||||
{t('modelManager.manual')}
|
||||
</Text>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>
|
||||
{errors.description}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Weights */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* VAE */}
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<HStack width="100%">
|
||||
{/* Width */}
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelManager.width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
width="90%"
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelManager.width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Height */}
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelManager.height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
width="90%"
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelManager.height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
</HStack>
|
||||
|
||||
<IAIButton
|
||||
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
) : (
|
||||
<SearchModels />
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
|
@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
|
||||
import type { RootState } from 'app/store';
|
||||
import type { ReactElement } from 'react';
|
||||
|
||||
function FormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
export default function AddDiffusersModel() {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
|
||||
|
||||
return (
|
||||
<Flex>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2}>
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{t('modelManager.formMessageDiffusersModelLocation')}
|
||||
</Text>
|
||||
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* VAE Path */}
|
||||
<Text fontWeight="bold">
|
||||
{t('modelManager.formMessageDiffusersVAELocation')}
|
||||
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIButton type="submit" isLoading={isProcessing}>
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import AddCheckpointModel from './AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddDiffusersModel';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
|
||||
function AddModelBox({
|
||||
text,
|
||||
@ -83,8 +84,22 @@ export default function AddModel() {
|
||||
closeOnOverlayClick={false}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent margin="auto" paddingInlineEnd={4}>
|
||||
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
|
||||
<ModalContent margin="auto">
|
||||
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
|
||||
{addNewModelUIOption !== null && (
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
position="absolute"
|
||||
variant="ghost"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={2}
|
||||
icon={<FaArrowLeft />}
|
||||
/>
|
||||
)}
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
{addNewModelUIOption == null && (
|
||||
|
@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
|
||||
import ModelConvert from './ModelConvert';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
|
||||
import { isEqual, pickBy } from 'lodash';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -12,14 +12,13 @@ import {
|
||||
RadioGroup,
|
||||
Spacer,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { FaPlus, FaSearch } from 'react-icons/fa';
|
||||
import { FaSearch, FaTrash } from 'react-icons/fa';
|
||||
|
||||
import { addNewModel, searchForModels } from 'app/socketio/actions';
|
||||
import {
|
||||
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { forEach, remove } from 'lodash';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const existingModelsSelector = createSelector([systemSelector], (system) => {
|
||||
const { model_list } = system;
|
||||
@ -71,34 +70,32 @@ function SearchModelEntry({
|
||||
};
|
||||
|
||||
return (
|
||||
<VStack>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<IAICheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAICheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4} alignItems="center" justifyContent="space-between">
|
||||
<IAICheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAICheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
</VStack>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@ -215,10 +212,10 @@ export default function SearchModels() {
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
{newFoundModels}
|
||||
{shouldShowExistingModelsInSearch && existingFoundModels}
|
||||
</>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@ -245,26 +242,26 @@ export default function SearchModels() {
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 500,
|
||||
fontSize: 'sm',
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.checkpointFolder')}
|
||||
</Text>
|
||||
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
|
||||
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
|
||||
</Flex>
|
||||
<Spacer />
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<BiReset />}
|
||||
icon={<FaSearch />}
|
||||
fontSize={18}
|
||||
disabled={isProcessing}
|
||||
onClick={() => dispatch(searchForModels(searchFolder))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
onClick={resetSearchModelHandler}
|
||||
/>
|
||||
</Flex>
|
||||
@ -276,9 +273,9 @@ export default function SearchModels() {
|
||||
}}
|
||||
>
|
||||
{({ handleSubmit }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<HStack columnGap={2} alignItems="flex-end" width="100%">
|
||||
<FormControl isRequired width="lg">
|
||||
<IAIForm onSubmit={handleSubmit} width="100%">
|
||||
<HStack columnGap={2} alignItems="flex-end">
|
||||
<FormControl flexGrow={1}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="checkpointFolder"
|
||||
@ -294,12 +291,12 @@ export default function SearchModels() {
|
||||
tooltip={t('modelManager.findModels')}
|
||||
type="submit"
|
||||
disabled={isProcessing}
|
||||
paddingX={10}
|
||||
px={8}
|
||||
>
|
||||
{t('modelManager.findModels')}
|
||||
</IAIButton>
|
||||
</HStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
)}
|
||||
@ -410,7 +407,6 @@ export default function SearchModels() {
|
||||
maxHeight={72}
|
||||
overflowY="scroll"
|
||||
borderRadius="sm"
|
||||
paddingInlineEnd={4}
|
||||
gap={2}
|
||||
>
|
||||
{foundModels.length > 0 ? (
|
||||
|
@ -61,47 +61,53 @@ const SiteHeader = () => {
|
||||
|
||||
<LanguagePicker />
|
||||
|
||||
<IAIIconButton
|
||||
aria-label={t('common.reportBugLabel')}
|
||||
tooltip={t('common.reportBugLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={
|
||||
<Link isExternal href="http://github.com/invoke-ai/InvokeAI/issues">
|
||||
<FaBug />
|
||||
</Link>
|
||||
}
|
||||
/>
|
||||
<Link
|
||||
isExternal
|
||||
href="http://github.com/invoke-ai/InvokeAI/issues"
|
||||
marginBottom="-0.25rem"
|
||||
>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.reportBugLabel')}
|
||||
tooltip={t('common.reportBugLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={<FaBug />}
|
||||
/>
|
||||
</Link>
|
||||
|
||||
<IAIIconButton
|
||||
aria-label={t('common.githubLabel')}
|
||||
tooltip={t('common.githubLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={
|
||||
<Link isExternal href="http://github.com/invoke-ai/InvokeAI">
|
||||
<FaGithub />
|
||||
</Link>
|
||||
}
|
||||
/>
|
||||
<Link
|
||||
isExternal
|
||||
href="http://github.com/invoke-ai/InvokeAI"
|
||||
marginBottom="-0.25rem"
|
||||
>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.githubLabel')}
|
||||
tooltip={t('common.githubLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={<FaGithub />}
|
||||
/>
|
||||
</Link>
|
||||
|
||||
<IAIIconButton
|
||||
aria-label={t('common.discordLabel')}
|
||||
tooltip={t('common.discordLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={
|
||||
<Link isExternal href="https://discord.gg/ZmtBAhwWhy">
|
||||
<FaDiscord />
|
||||
</Link>
|
||||
}
|
||||
/>
|
||||
<Link
|
||||
isExternal
|
||||
href="https://discord.gg/ZmtBAhwWhy"
|
||||
marginBottom="-0.25rem"
|
||||
>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.discordLabel')}
|
||||
tooltip={t('common.discordLabel')}
|
||||
variant="link"
|
||||
data-variant="link"
|
||||
fontSize={20}
|
||||
size="sm"
|
||||
icon={<FaDiscord />}
|
||||
/>
|
||||
</Link>
|
||||
|
||||
<SettingsModal>
|
||||
<IAIIconButton
|
||||
@ -119,4 +125,5 @@ const SiteHeader = () => {
|
||||
);
|
||||
};
|
||||
|
||||
SiteHeader.displayName = 'SiteHeader';
|
||||
export default SiteHeader;
|
||||
|
@ -50,7 +50,6 @@ export default function ThemeChanger() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('common.themeLabel')}
|
||||
|
@ -166,20 +166,8 @@ export default function InvokeTabs() {
|
||||
[]
|
||||
);
|
||||
|
||||
/**
|
||||
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
|
||||
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
|
||||
*
|
||||
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
|
||||
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
|
||||
* the watcher for the hidden drawers, closing the open drawer.
|
||||
*
|
||||
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
|
||||
* tab's drawer.
|
||||
*/
|
||||
return (
|
||||
<Tabs
|
||||
isLazy
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
onChange={(index: number) => {
|
||||
|
@ -93,12 +93,9 @@ const ResizableDrawer = ({
|
||||
useOutsideClick({
|
||||
ref: outsideClickRef,
|
||||
handler: () => {
|
||||
if (isPinned) {
|
||||
return;
|
||||
}
|
||||
|
||||
onClose();
|
||||
},
|
||||
enabled: isOpen && !isPinned,
|
||||
});
|
||||
|
||||
const handleEnables = useMemo(
|
||||
|
@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<Box
|
||||
sx={{
|
||||
|
@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
isLazy={false}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.canvasSettings')}
|
||||
|
File diff suppressed because one or more lines are too long
@ -46,7 +46,7 @@ export default defineConfig(({ mode }) => {
|
||||
* overrides any target specified here.
|
||||
*/
|
||||
// target: 'esnext',
|
||||
chunkSizeWarningLimit: 1500, // we don't really care about chunk size
|
||||
chunkSizeWarningLimit: 1500, // we don't really care about chunk size,
|
||||
},
|
||||
};
|
||||
if (mode == 'development') {
|
||||
|
@ -38,16 +38,16 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==0.1.10",
|
||||
"compel==1.0.4",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.14",
|
||||
"dnspython==2.2.1",
|
||||
"einops",
|
||||
"eventlet",
|
||||
"facexlib",
|
||||
"fastapi==0.85.0",
|
||||
"fastapi-events==0.6.0",
|
||||
"fastapi-socketio==0.0.9",
|
||||
"fastapi==0.94.1",
|
||||
"fastapi-events==0.8.0",
|
||||
"fastapi-socketio==0.0.10",
|
||||
"flask==2.1.3",
|
||||
"flask_cors==3.0.10",
|
||||
"flask_socketio==5.3.0",
|
||||
@ -75,7 +75,7 @@ dependencies = [
|
||||
"torchvision>=0.14.1",
|
||||
"torchmetrics",
|
||||
"transformers~=4.26",
|
||||
"uvicorn[standard]==0.20.0",
|
||||
"uvicorn[standard]==0.21.1",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
@ -139,8 +139,24 @@ version = { attr = "invokeai.version.__version__" }
|
||||
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
||||
"invokeai.frontend.web.dist" = ["**"]
|
||||
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-p pytest_cov --junitxml=junit/test-results.xml --cov-report=term:skip-covered --cov=ldm/invoke --cov=backend --cov-branch"
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["invokeai"]
|
||||
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
|
||||
[tool.coverage.report]
|
||||
show_missing = true
|
||||
fail_under = 85 # let's set something sensible on Day 1 ...
|
||||
[tool.coverage.json]
|
||||
output = "coverage/coverage.json"
|
||||
pretty_print = true
|
||||
[tool.coverage.html]
|
||||
directory = "coverage/html"
|
||||
[tool.coverage.xml]
|
||||
output = "coverage/index.xml"
|
||||
#=== End: PyTest and Coverage
|
||||
|
||||
[flake8]
|
||||
max-line-length = 120
|
||||
|
@ -105,17 +105,20 @@
|
||||
|
||||
// Start building nodes
|
||||
var id = 1;
|
||||
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
|
||||
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
|
||||
id++;
|
||||
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
|
||||
id++;
|
||||
var upscaleNode = {"id": id.toString(), "type": "show_image" };
|
||||
id++
|
||||
|
||||
nodes = {};
|
||||
nodes[initialNode.id] = initialNode;
|
||||
nodes[i2iNode.id] = i2iNode;
|
||||
nodes[upscaleNode.id] = upscaleNode;
|
||||
links = [
|
||||
[{ "node_id": initialNode.id, field: "image" },
|
||||
{ "node_id": upscaleNode.id, field: "image" }]
|
||||
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
|
||||
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
|
||||
];
|
||||
// expandSize = 128;
|
||||
// for (var i = 0; i < 6; ++i) {
|
||||
|
@ -1,15 +1,18 @@
|
||||
from invokeai.app.invocations.image import *
|
||||
|
||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
# Helpers
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
return Edge(
|
||||
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||
destination=EdgeConnection(node_id = to_id, field = to_field)
|
||||
)
|
||||
|
||||
# Tests
|
||||
def test_connections_are_compatible():
|
||||
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
|
||||
assert g.get_node("3").prompt == "Banana sushi"
|
||||
|
||||
assert len(g.edges) == 1
|
||||
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
|
||||
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
|
||||
|
||||
def test_graph_fails_to_update_node_id_if_conflict():
|
||||
g = Graph()
|
||||
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
|
||||
assert g2.nodes['1'] is not None
|
||||
assert g2.nodes['2'] is not None
|
||||
assert len(g2.edges) == 1
|
||||
assert g2.edges[0][0].node_id == '1'
|
||||
assert g2.edges[0][0].field == 'image'
|
||||
assert g2.edges[0][1].node_id == '2'
|
||||
assert g2.edges[0][1].field == 'image'
|
||||
assert g2.edges[0].source.node_id == '1'
|
||||
assert g2.edges[0].source.field == 'image'
|
||||
assert g2.edges[0].destination.node_id == '2'
|
||||
assert g2.edges[0].destination.field == 'image'
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
|
@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import EdgeConnection
|
||||
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
return Edge(
|
||||
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||
destination=EdgeConnection(node_id = to_id, field = to_field))
|
||||
|
||||
|
||||
class TestEvent:
|
||||
|
Loading…
Reference in New Issue
Block a user