mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into patch-1
This commit is contained in:
commit
07ea806553
@ -1,6 +0,0 @@
|
||||
[run]
|
||||
omit='.env/*'
|
||||
source='.'
|
||||
|
||||
[report]
|
||||
show_missing = true
|
8
.github/CODEOWNERS
vendored
8
.github/CODEOWNERS
vendored
@ -1,16 +1,16 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @mauwii @lstein
|
||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @mauwii @tildebyte
|
||||
/mkdocs.yml @lstein @mauwii
|
||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||
/docker/ @mauwii @lstein
|
||||
/docker/ @mauwii @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
|
10
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
10
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
@ -65,6 +65,16 @@ body:
|
||||
placeholder: 8GB
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: version-number
|
||||
attributes:
|
||||
label: What version did you experience this issue on?
|
||||
description: |
|
||||
Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: X.X.X
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
|
2
.github/workflows/mkdocs-material.yml
vendored
2
.github/workflows/mkdocs-material.yml
vendored
@ -34,6 +34,8 @@ jobs:
|
||||
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
|
1
.github/workflows/test-invoke-pip-skip.yml
vendored
1
.github/workflows/test-invoke-pip-skip.yml
vendored
@ -6,7 +6,6 @@ on:
|
||||
- '!pyproject.toml'
|
||||
- '!invokeai/**'
|
||||
- 'invokeai/frontend/web/**'
|
||||
- '!invokeai/frontend/web/dist/**'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
|
2
.github/workflows/test-invoke-pip.yml
vendored
2
.github/workflows/test-invoke-pip.yml
vendored
@ -7,13 +7,11 @@ on:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'invokeai/frontend/web/dist/**'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'invokeai/frontend/web/dist/**'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -63,6 +63,7 @@ pip-delete-this-directory.txt
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coveragerc
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
@ -73,6 +74,7 @@ cov.xml
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
.pytest.ini
|
||||
cover/
|
||||
junit/
|
||||
|
||||
|
@ -1,5 +0,0 @@
|
||||
[pytest]
|
||||
DJANGO_SETTINGS_MODULE = webtas.settings
|
||||
; python_files = tests.py test_*.py *_tests.py
|
||||
|
||||
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
|
@ -139,7 +139,7 @@ not supported.
|
||||
_For Windows/Linux with an NVIDIA GPU:_
|
||||
|
||||
```terminal
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
```
|
||||
|
||||
_For Linux with an AMD GPU:_
|
||||
|
4
coverage/.gitignore
vendored
Normal file
4
coverage/.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
BIN
docs/assets/contributing/html-detail.png
Normal file
BIN
docs/assets/contributing/html-detail.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 470 KiB |
BIN
docs/assets/contributing/html-overview.png
Normal file
BIN
docs/assets/contributing/html-overview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 457 KiB |
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
@ -0,0 +1,83 @@
|
||||
# Local Development
|
||||
|
||||
If you are looking to contribute you will need to have a local development
|
||||
environment. See the
|
||||
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
|
||||
full details.
|
||||
|
||||
Broadly this involves cloning the repository, installing the pre-reqs, and
|
||||
InvokeAI (in editable form). Assuming this is working, choose your area of
|
||||
focus.
|
||||
|
||||
## Documentation
|
||||
|
||||
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
|
||||
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
|
||||
written in markdown files under the `./docs` folder and then built into a static
|
||||
website for hosting with GitHub Pages at
|
||||
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
|
||||
|
||||
To contribute to the documentation you'll need to install the dependencies. Note
|
||||
the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[docs]"
|
||||
```
|
||||
|
||||
Now, to run the documentation locally with hot-reloading for changes made.
|
||||
|
||||
```zsh
|
||||
mkdocs serve
|
||||
```
|
||||
|
||||
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
|
||||
access.
|
||||
|
||||
## Backend
|
||||
|
||||
The backend is contained within the `./invokeai/backend` folder structure. To
|
||||
get started however please install the development dependencies.
|
||||
|
||||
From the root of the repository run the following command. Note the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[test]"
|
||||
```
|
||||
|
||||
This in an optional group of packages which is defined within the
|
||||
`pyproject.toml` and will be required for testing the changes you make the the
|
||||
code.
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
|
||||
be found under the `./tests` folder and can be run with a single `pytest`
|
||||
command. Optionally, to review test coverage you can append `--cov`.
|
||||
|
||||
```zsh
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition a more
|
||||
detailed report is created in both XML and HTML format in the `./coverage`
|
||||
folder. The HTML one in particular can help identify missing statements
|
||||
requiring tests to ensure coverage. This can be run by opening
|
||||
`./coverage/html/index.html`.
|
||||
|
||||
For example.
|
||||
|
||||
```zsh
|
||||
pytest --cov; open ./coverage/html/index.html
|
||||
```
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Front End
|
||||
|
||||
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
||||
|
||||
--8<-- "invokeai/frontend/web/README.md"
|
@ -17,7 +17,7 @@ notebooks.
|
||||
|
||||
You will need a GPU to perform training in a reasonable length of
|
||||
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
|
||||
library](../installation/070_INSTALL_XFORMERS) to accelerate the
|
||||
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
|
||||
training process further. During training, about ~8 GB is temporarily
|
||||
needed in order to store intermediate models, checkpoints and logs.
|
||||
|
||||
|
@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
|
||||
brew install opencv
|
||||
```
|
||||
|
||||
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
|
||||
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
|
||||
|
||||
## Linux
|
||||
|
||||
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
|
||||
`python`, and then at the `>>>` line type
|
||||
`from patchmatch import patch_match`: It should look like the follwing:
|
||||
`from patchmatch import patch_match`: It should look like the following:
|
||||
|
||||
```py
|
||||
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
|
||||
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
|
||||
|
||||
If you see no errors, then you're ready to go!
|
||||
If you see no errors you're ready to go!
|
||||
|
@ -10,6 +10,7 @@ from pydantic.fields import Field
|
||||
from ...invocations import *
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
@ -92,7 +93,7 @@ async def get_session(
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
@ -125,7 +126,7 @@ async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
@ -186,7 +187,7 @@ async def delete_node(
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
|
||||
edge: Edge = Body(description="The edge to add"),
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -228,9 +229,9 @@ async def delete_edge(
|
||||
return Response(status_code=404)
|
||||
|
||||
try:
|
||||
edge = (
|
||||
EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
|
@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
@ -94,9 +94,9 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
edges = [
|
||||
(
|
||||
EdgeConnection(node_id=a.id, field=field),
|
||||
EdgeConnection(node_id=b.id, field=field),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
)
|
||||
for field in matching_fields
|
||||
]
|
||||
@ -111,16 +111,15 @@ class SessionError(Exception):
|
||||
def invoke_all(context: CliContext):
|
||||
"""Runs all invocations in the specified session"""
|
||||
context.invoker.invoke(context.session, invoke_all=True)
|
||||
while not context.session.is_complete():
|
||||
while not context.get_session().is_complete():
|
||||
# Wait some time
|
||||
session = context.get_session()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
@ -203,7 +202,7 @@ def invoke_cli():
|
||||
continue
|
||||
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges = []
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = (
|
||||
history[0] if current_id == start_id else str(current_id - 1)
|
||||
@ -225,19 +224,19 @@ def invoke_cli():
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
)
|
||||
matching_destinations = [e[1] for e in matching_edges]
|
||||
edges = [e for e in edges if e[1] not in matching_destinations]
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||
edges.append(
|
||||
(
|
||||
EdgeConnection(node_id=link[1], field=link[0]),
|
||||
EdgeConnection(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -4,6 +4,8 @@ from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
@ -12,7 +14,9 @@ from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
@ -41,18 +45,32 @@ class TextToImageInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, sample: Any = None, step: int = 0
|
||||
) -> None:
|
||||
self, context: InvocationContext, sample: Tensor, step: int
|
||||
) -> None:
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
self.id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
step,
|
||||
float(step) / float(self.steps),
|
||||
self.steps,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
def step_callback(sample, step=0):
|
||||
self.dispatch_progress(context, sample, step)
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
|
@ -1,7 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, TypedDict
|
||||
|
||||
ProgressImage = TypedDict(
|
||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||
)
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -23,8 +26,9 @@ class EventServiceBase:
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
percent: float,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
@ -32,8 +36,9 @@ class EventServiceBase:
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation_id=invocation_id,
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
percent=percent,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
|
||||
return hash(f"{self.node_id}.{self.field}")
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
|
||||
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
|
||||
|
||||
|
||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_outputs = get_type_hints(node_type.get_output_type())
|
||||
@ -194,7 +199,7 @@ class Graph(BaseModel):
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
)
|
||||
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
@ -251,7 +256,7 @@ class Graph(BaseModel):
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
"""Adds an edge to a graph
|
||||
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
@ -262,7 +267,7 @@ class Graph(BaseModel):
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
"""Deletes an edge from a graph"""
|
||||
|
||||
try:
|
||||
@ -280,7 +285,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set(
|
||||
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
|
||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||
)
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
@ -294,10 +299,10 @@ class Graph(BaseModel):
|
||||
if not all(
|
||||
(
|
||||
are_connections_compatible(
|
||||
self.get_node(e[0].node_id),
|
||||
e[0].field,
|
||||
self.get_node(e[1].node_id),
|
||||
e[1].field,
|
||||
self.get_node(e.source.node_id),
|
||||
e.source.field,
|
||||
self.get_node(e.destination.node_id),
|
||||
e.destination.field,
|
||||
)
|
||||
for e in self.edges
|
||||
)
|
||||
@ -328,58 +333,58 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge[0].node_id)
|
||||
to_node = self.get_node(edge[1].node_id)
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge[0].node_id, edge[1].node_id)
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge[0].field, to_node, edge[1].field
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[1].node_id, new_input=edge[0]
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge[0].node_id, new_output=edge[1]
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
|
||||
@ -438,15 +443,15 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[1].node_id
|
||||
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
edge[0],
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[1].field
|
||||
),
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.destination.field
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@ -454,51 +459,51 @@ class Graph(BaseModel):
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge[0].node_id
|
||||
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge[0].field
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.source.field
|
||||
),
|
||||
edge[1],
|
||||
destination=edge.destination
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(
|
||||
self, node_path: str, field: Optional[str] = None
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -522,37 +527,37 @@ class Graph(BaseModel):
|
||||
|
||||
def _get_output_edges(
|
||||
self, node_path: str, field: str
|
||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
||||
) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
(
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
||||
field=e[0].field,
|
||||
),
|
||||
EdgeConnection(
|
||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
||||
field=e[1].field,
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
||||
) -> list[tuple["Graph", str, Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
|
||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||
)
|
||||
|
||||
node_id = (
|
||||
@ -580,8 +585,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -622,8 +627,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -684,7 +689,7 @@ class Graph(BaseModel):
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
@ -711,7 +716,7 @@ class Graph(BaseModel):
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
g.add_edges_from(
|
||||
[
|
||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||
@ -768,6 +773,24 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'id',
|
||||
'graph',
|
||||
'execution_graph',
|
||||
'executed',
|
||||
'executed_history',
|
||||
'results',
|
||||
'errors',
|
||||
'prepared_source_mapping',
|
||||
'source_prepared_mapping',
|
||||
]
|
||||
}
|
||||
|
||||
def next(self) -> BaseInvocation | None:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
@ -841,13 +864,13 @@ class GraphExecutionState(BaseModel):
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1]
|
||||
for n in iteration_node_map
|
||||
if n[0] == input_collection_edge[0].node_id
|
||||
if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[
|
||||
input_collection_prepared_node_id
|
||||
]
|
||||
input_collection = getattr(
|
||||
input_collection_prepared_node_output, input_collection_edge[0].field
|
||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||
)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
@ -864,11 +887,11 @@ class GraphExecutionState(BaseModel):
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (
|
||||
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
|
||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||
):
|
||||
new_edge = (
|
||||
EdgeConnection(node_id=input_node_id, field=edge[0].field),
|
||||
EdgeConnection(node_id="", field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||
)
|
||||
new_edges.append(new_edge)
|
||||
|
||||
@ -893,9 +916,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
new_edge = (
|
||||
edge[0],
|
||||
EdgeConnection(node_id=new_node.id, field=edge[1].field),
|
||||
new_edge = Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
)
|
||||
self.execution_graph.add_edge(new_edge)
|
||||
|
||||
@ -1043,26 +1066,26 @@ class GraphExecutionState(BaseModel):
|
||||
return self.execution_graph.nodes[next_node]
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
|
||||
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||
if isinstance(node, CollectInvocation):
|
||||
output_collection = [
|
||||
getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
for edge in input_edges
|
||||
if edge[1].field == "item"
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
setattr(node, "collection", output_collection)
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
|
||||
setattr(node, edge[1].field, output_value)
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
setattr(node, edge.destination.field, output_value)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
if edge[1].node_id in self.source_prepared_mapping:
|
||||
if edge.destination.node_id in self.source_prepared_mapping:
|
||||
return False
|
||||
|
||||
# Otherwise, the edge is valid
|
||||
@ -1089,17 +1112,17 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
|
||||
)
|
||||
self.graph.add_edge(edge)
|
||||
|
||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
||||
if not self._is_node_updatable(edge[1].node_id):
|
||||
def delete_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
)
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
@ -490,7 +490,7 @@ class Args(object):
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0, 9),
|
||||
choices=range(0, 10),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
||||
)
|
||||
@ -943,7 +943,6 @@ class Args(object):
|
||||
"--png_compression",
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0, 10),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",
|
||||
|
@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
@ -497,7 +498,8 @@ class Generator:
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
def sample_to_lowres_estimated_image(self, samples):
|
||||
@staticmethod
|
||||
def sample_to_lowres_estimated_image(samples):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
|
@ -159,6 +159,7 @@ class Inpaint(Img2Img):
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise)
|
||||
result = make_image(seam_noise, seed)
|
||||
|
||||
return result
|
||||
|
||||
@ -342,6 +343,7 @@ class Inpaint(Img2Img):
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
|
@ -372,22 +372,32 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(" | Extracting EMA weights (usually better for inference)")
|
||||
print(" | Extracting EMA weights (usually better for inference)")
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key
|
||||
)
|
||||
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
|
||||
if flat_ema_key in checkpoint:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key
|
||||
)
|
||||
elif flat_ema_key_alt in checkpoint:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
flat_ema_key_alt
|
||||
)
|
||||
else:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||
key
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(unet_key):
|
||||
if key.startswith("model.diffusion_model") and key in checkpoint:
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||
|
||||
return text_model
|
||||
|
||||
def replace_checkpoint_vae(checkpoint, vae_path:str):
|
||||
if vae_path.endswith(".safetensors"):
|
||||
vae_ckpt = load_file(vae_path)
|
||||
else:
|
||||
vae_ckpt = torch.load(vae_path, map_location="cpu")
|
||||
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
|
||||
for vae_key in state_dict:
|
||||
new_key = f'first_stage_model.{vae_key}'
|
||||
checkpoint[new_key] = state_dict[vae_key]
|
||||
|
||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
extract_ema: bool = True,
|
||||
upcast_attn: bool = False,
|
||||
vae: AutoencoderKL = None,
|
||||
vae_path: str = None,
|
||||
precision: torch.dtype = torch.float32,
|
||||
return_generator_pipeline: bool = False,
|
||||
scan_needed:bool=True,
|
||||
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
||||
"""
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
||||
running stable diffusion 2.1.
|
||||
:param vae: A diffusers VAE to load into the pipeline.
|
||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||
"""
|
||||
|
||||
with warnings.catch_warnings():
|
||||
@ -1074,11 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
checkpoint = (
|
||||
load_file(checkpoint_path)
|
||||
if Path(checkpoint_path).suffix == ".safetensors"
|
||||
else torch.load(checkpoint_path)
|
||||
)
|
||||
if Path(checkpoint_path).suffix == '.ckpt':
|
||||
if scan_needed:
|
||||
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
|
||||
cache_dir = global_cache_dir("hub")
|
||||
pipeline_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
@ -1090,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print(" | global_step key not found in model")
|
||||
print(" | global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
@ -1201,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# Convert the VAE model, or use the one passed
|
||||
if not vae:
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
print(f" | Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
print(" | Using checkpoint model's original VAE")
|
||||
|
||||
if vae:
|
||||
print(" | Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
)
|
||||
@ -1213,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
else:
|
||||
print(" | Using external VAE specified in config")
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
|
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
@ -45,9 +45,6 @@ class SDLegacyType(Enum):
|
||||
UNKNOWN = 99
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
|
||||
}
|
||||
|
||||
class ModelManager(object):
|
||||
'''
|
||||
@ -285,13 +282,13 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f"** deleting file {weights}")
|
||||
print(f"** Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f"** deleting directory {path}")
|
||||
print(f"** Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f"** deleting the cached model directory for {repo_id}")
|
||||
print(f"** Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
@ -362,6 +359,7 @@ class ModelManager(object):
|
||||
raise NotImplementedError(
|
||||
f"Unknown model format {model_name}: {model_format}"
|
||||
)
|
||||
self._add_embeddings_to_model(model)
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
@ -381,9 +379,9 @@ class ModelManager(object):
|
||||
|
||||
print(f">> Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
print(" | Using faster float16 precision")
|
||||
print(" | Using faster float16 precision")
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
print(" | Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
@ -434,10 +432,9 @@ class ModelManager(object):
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
|
||||
self._add_embeddings_to_model(pipeline)
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
|
||||
@ -457,15 +454,21 @@ class ModelManager(object):
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
self.offload_model(self.current_model)
|
||||
if vae_config := self._choose_diffusers_vae(model_name):
|
||||
vae = self._load_vae(vae_config)
|
||||
try:
|
||||
if self.list_models()[self.current_model]['status'] == 'active':
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
vae_path = None
|
||||
if vae:
|
||||
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=weights,
|
||||
original_config_file=config,
|
||||
vae=vae,
|
||||
vae_path=vae_path,
|
||||
return_generator_pipeline=True,
|
||||
precision=torch.float16 if self.precision == "float16" else torch.float32,
|
||||
)
|
||||
@ -512,18 +515,20 @@ class ModelManager(object):
|
||||
print(f">> Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def scan_model(self, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
print(f">> Scanning Model: {model_name}")
|
||||
print(f" | Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
@ -546,7 +551,7 @@ class ModelManager(object):
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(">> Model scanned ok")
|
||||
print(" | Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@ -665,7 +670,7 @@ class ModelManager(object):
|
||||
print(f">> Probing {thing} for import")
|
||||
|
||||
if thing.startswith(("http:", "https:", "ftp:")):
|
||||
print(f" | {thing} appears to be a URL")
|
||||
print(f" | {thing} appears to be a URL")
|
||||
model_path = self._resolve_path(
|
||||
thing, "models/ldm/stable-diffusion-v1"
|
||||
) # _resolve_path does a download if needed
|
||||
@ -673,15 +678,15 @@ class ModelManager(object):
|
||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||
print(
|
||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||
)
|
||||
return
|
||||
else:
|
||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers file on disk")
|
||||
print(f" | {thing} appears to be a diffusers file on disk")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
@ -692,13 +697,13 @@ class ModelManager(object):
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
if (Path(thing) / "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers model.")
|
||||
print(f" | {thing} appears to be a diffusers model.")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||
)
|
||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||
Path(thing).rglob("*.safetensors")
|
||||
@ -710,7 +715,7 @@ class ModelManager(object):
|
||||
return model_name
|
||||
|
||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
@ -727,32 +732,33 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: # already imported
|
||||
print(" | Already imported. Skipping")
|
||||
print(" | Already imported. Skipping")
|
||||
return model_path.stem
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
checkpoint = (
|
||||
safetensors.torch.load_file(model_path)
|
||||
if model_path.suffix == ".safetensors"
|
||||
else torch.load(model_path)
|
||||
)
|
||||
checkpoint = None
|
||||
if model_path.suffix.endswith((".ckpt",".pt")):
|
||||
self.scan_model(model_path,model_path)
|
||||
checkpoint = torch.load(model_path)
|
||||
else:
|
||||
checkpoint = safetensors.torch.load_file(model_path)
|
||||
|
||||
# additional probing needed if no config file provided
|
||||
if model_config_file is None:
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
print(" | SD-v1 model detected")
|
||||
print(" | SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(
|
||||
" | SD-v2-v model detected; model will be converted to diffusers format"
|
||||
" | SD-v2-v model detected; model will be converted to diffusers format"
|
||||
)
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
@ -760,7 +766,7 @@ class ModelManager(object):
|
||||
convert = True
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(
|
||||
" | SD-v2-e model detected; model will be converted to diffusers format"
|
||||
" | SD-v2-e model detected; model will be converted to diffusers format"
|
||||
)
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
@ -788,18 +794,21 @@ class ModelManager(object):
|
||||
model_description=description,
|
||||
original_config_file=model_config_file,
|
||||
commit_to_conf=commit_to_conf,
|
||||
scan_needed=False,
|
||||
)
|
||||
return model_name
|
||||
|
||||
def convert_and_import(
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae:dict=None,
|
||||
vae_path:Path=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
scan_needed: bool=True,
|
||||
) -> str:
|
||||
"""
|
||||
Convert a legacy ckpt weights file to diffuser model and import
|
||||
@ -827,18 +836,23 @@ class ModelManager(object):
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = self._load_vae(vae) if vae else None
|
||||
vae_model=None
|
||||
if vae:
|
||||
vae_model=self._load_vae(vae)
|
||||
vae_path=None
|
||||
convert_ckpt_to_diffusers(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
extract_ema=True,
|
||||
original_config_file=original_config_file,
|
||||
vae=vae_model,
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
print(
|
||||
f" | Success. Optimized model is now located at {str(diffusers_path)}"
|
||||
f" | Success. Optimized model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
print(f" | Writing new config file entry for {model_name}")
|
||||
print(f" | Writing new config file entry for {model_name}")
|
||||
new_config = dict(
|
||||
path=str(diffusers_path),
|
||||
description=model_description,
|
||||
@ -849,7 +863,7 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
print(">> Conversion succeeded")
|
||||
print(" | Conversion succeeded")
|
||||
except Exception as e:
|
||||
print(f"** Conversion failed: {str(e)}")
|
||||
print(
|
||||
@ -879,36 +893,6 @@ class ModelManager(object):
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def _choose_diffusers_vae(
|
||||
self, model_name: str, vae: str = None
|
||||
) -> Union[dict, str]:
|
||||
# In the event that the original entry is using a custom ckpt VAE, we try to
|
||||
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
|
||||
# I would prefer to do this differently: We load the ckpt model into memory, swap the
|
||||
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
|
||||
# VAE is built into the model. However, when I tried this I got obscure key errors.
|
||||
if vae:
|
||||
return vae
|
||||
if model_name in self.config and (
|
||||
vae_ckpt_path := self.model_info(model_name).get("vae", None)
|
||||
):
|
||||
vae_basename = Path(vae_ckpt_path).stem
|
||||
diffusers_vae = None
|
||||
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
|
||||
print(
|
||||
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
|
||||
)
|
||||
vae = {"repo_id": diffusers_vae}
|
||||
else:
|
||||
print(
|
||||
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
|
||||
)
|
||||
print(
|
||||
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
|
||||
)
|
||||
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
|
||||
return vae
|
||||
|
||||
def _make_cache_room(self) -> None:
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
@ -1105,7 +1089,7 @@ class ModelManager(object):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(" | Calculating sha256 hash of model files")
|
||||
print(" | Calculating sha256 hash of model files")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
count = 0
|
||||
@ -1117,7 +1101,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@ -1162,12 +1146,12 @@ class ModelManager(object):
|
||||
local_files_only=not Globals.internet_available,
|
||||
)
|
||||
|
||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||
if using_fp16:
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
print(" | Using more accurate float32 precision")
|
||||
fp_args_list = [{}]
|
||||
|
||||
vae = None
|
||||
@ -1208,7 +1192,7 @@ class ModelManager(object):
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
print(
|
||||
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
)
|
||||
strategy.execute()
|
||||
|
||||
|
@ -6,7 +6,6 @@ The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
@ -15,7 +14,6 @@ from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
ModelSearchArguments,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
print(
|
||||
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
||||
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
print(
|
||||
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
print(
|
||||
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
|
@ -3,6 +3,9 @@ import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
import io
|
||||
import base64
|
||||
|
||||
from collections import abc
|
||||
from inspect import isfunction
|
||||
from pathlib import Path
|
||||
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
|
||||
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
||||
result = download_with_resume(url, dest, access_token=None)
|
||||
return result is not None
|
||||
|
||||
|
||||
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
"""
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format=image_format)
|
||||
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
|
||||
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
|
||||
buffered.getvalue()
|
||||
).decode("UTF-8")
|
||||
return image_base64
|
||||
|
@ -772,16 +772,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
original_config_file = Path(model_info["config"])
|
||||
model_name = model_name_or_path
|
||||
model_description = model_info["description"]
|
||||
vae = model_info["vae"]
|
||||
vae_path = model_info.get("vae")
|
||||
else:
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
|
||||
Path(vae).stem
|
||||
):
|
||||
vae_repo = dict(repo_id=vae_repo)
|
||||
else:
|
||||
vae_repo = None
|
||||
model_name = manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffusers_path=Path(
|
||||
@ -790,7 +784,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
model_name=model_name,
|
||||
model_description=model_description,
|
||||
original_config_file=original_config_file,
|
||||
vae=vae_repo,
|
||||
vae_path=vae_path,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
Normal file
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-982926da.js
vendored
188
invokeai/frontend/web/dist/assets/App-982926da.js
vendored
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-2ad84bef.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-e63a2dc4.js";var Or=`
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-2ad84bef.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
3
invokeai/frontend/web/dist/locales/en.json
vendored
3
invokeai/frontend/web/dist/locales/en.json
vendored
@ -64,6 +64,8 @@
|
||||
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"cancel": "Cancel",
|
||||
"accept": "Accept",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
@ -333,6 +335,7 @@
|
||||
"addNewModel": "Add New Model",
|
||||
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||
"addDiffuserModel": "Add Diffusers",
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"name": "Name",
|
||||
|
36
invokeai/frontend/web/index.d.ts
vendored
36
invokeai/frontend/web/index.d.ts
vendored
@ -1,3 +1,7 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
|
||||
export {};
|
||||
|
||||
declare module 'redux-socket.io-middleware';
|
||||
@ -40,5 +44,35 @@ declare global {
|
||||
/* eslint-enable @typescript-eslint/no-explicit-any */
|
||||
}
|
||||
|
||||
declare function Invoke(): React.JSX;
|
||||
declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
|
||||
public constructor(props: ThemeChangerProps);
|
||||
}
|
||||
|
||||
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
|
||||
public constructor(props: InvokeAILogoComponentProps);
|
||||
}
|
||||
|
||||
declare class IAIPopover extends React.Component<IAIPopoverProps> {
|
||||
public constructor(props: IAIPopoverProps);
|
||||
}
|
||||
|
||||
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
|
||||
public constructor(props: IAIIconButtonProps);
|
||||
}
|
||||
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
}
|
||||
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
InvokeAiLogoComponent,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
||||
export = Invoke;
|
||||
|
@ -6,7 +6,6 @@
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"build:package": "vite build --mode=package",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "داكن",
|
||||
"lightTheme": "فاتح",
|
||||
"greenTheme": "أخضر",
|
||||
"text2img": "نص إلى صورة",
|
||||
"img2img": "صورة إلى صورة",
|
||||
"unifiedCanvas": "لوحة موحدة",
|
||||
"nodes": "عقد",
|
||||
|
@ -7,7 +7,6 @@
|
||||
"darkTheme": "Dunkel",
|
||||
"lightTheme": "Hell",
|
||||
"greenTheme": "Grün",
|
||||
"text2img": "Text zu Bild",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten",
|
||||
"langGerman": "Deutsch",
|
||||
|
@ -64,6 +64,8 @@
|
||||
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
|
||||
"upload": "Upload",
|
||||
"close": "Close",
|
||||
"cancel": "Cancel",
|
||||
"accept": "Accept",
|
||||
"load": "Load",
|
||||
"back": "Back",
|
||||
"statusConnected": "Connected",
|
||||
@ -333,6 +335,7 @@
|
||||
"addNewModel": "Add New Model",
|
||||
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||
"addDiffuserModel": "Add Diffusers",
|
||||
"scanForModels": "Scan For Models",
|
||||
"addManually": "Add Manually",
|
||||
"manual": "Manual",
|
||||
"name": "Name",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Oscuro",
|
||||
"lightTheme": "Claro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto a Imagen",
|
||||
"img2img": "Imagen a Imagen",
|
||||
"unifiedCanvas": "Lienzo Unificado",
|
||||
"nodes": "Nodos",
|
||||
@ -70,7 +69,11 @@
|
||||
"langHebrew": "Hebreo",
|
||||
"pinOptionsPanel": "Pin del panel de opciones",
|
||||
"loading": "Cargando",
|
||||
"loadingInvokeAI": "Cargando invocar a la IA"
|
||||
"loadingInvokeAI": "Cargando invocar a la IA",
|
||||
"postprocessing": "Tratamiento posterior",
|
||||
"txt2img": "De texto a imagen",
|
||||
"accept": "Aceptar",
|
||||
"cancel": "Cancelar"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generaciones",
|
||||
@ -404,7 +407,8 @@
|
||||
"none": "ninguno",
|
||||
"pickModelType": "Elige el tipo de modelo",
|
||||
"v2_768": "v2 (768px)",
|
||||
"addDifference": "Añadir una diferencia"
|
||||
"addDifference": "Añadir una diferencia",
|
||||
"scanForModels": "Buscar modelos"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Imágenes",
|
||||
@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Guardar automáticamente en galería",
|
||||
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
||||
"limitStrokesToBox": "Limitar trazos a la caja",
|
||||
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
|
||||
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
|
||||
"clearCanvasHistory": "Limpiar historial de lienzo",
|
||||
"clearHistory": "Limpiar historial",
|
||||
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Sombre",
|
||||
"lightTheme": "Clair",
|
||||
"greenTheme": "Vert",
|
||||
"text2img": "Texte en image",
|
||||
"img2img": "Image en image",
|
||||
"unifiedCanvas": "Canvas unifié",
|
||||
"nodes": "Nœuds",
|
||||
@ -47,7 +46,19 @@
|
||||
"statusLoadingModel": "Chargement du modèle",
|
||||
"statusModelChanged": "Modèle changé",
|
||||
"discordLabel": "Discord",
|
||||
"githubLabel": "Github"
|
||||
"githubLabel": "Github",
|
||||
"accept": "Accepter",
|
||||
"statusMergingModels": "Mélange des modèles",
|
||||
"loadingInvokeAI": "Chargement de Invoke AI",
|
||||
"cancel": "Annuler",
|
||||
"langEnglish": "Anglais",
|
||||
"statusConvertingModel": "Conversion du modèle",
|
||||
"statusModelConverted": "Modèle converti",
|
||||
"loading": "Chargement",
|
||||
"pinOptionsPanel": "Épingler la page d'options",
|
||||
"statusMergedModels": "Modèles mélangés",
|
||||
"txt2img": "Texte vers image",
|
||||
"postprocessing": "Post-Traitement"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Générations",
|
||||
@ -518,5 +529,15 @@
|
||||
"betaDarkenOutside": "Assombrir à l'extérieur",
|
||||
"betaLimitToBox": "Limiter à la boîte",
|
||||
"betaPreserveMasked": "Conserver masqué"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Charger une image",
|
||||
"reset": "Réinitialiser",
|
||||
"nextImage": "Image suivante",
|
||||
"previousImage": "Image précédente",
|
||||
"useThisParameter": "Utiliser ce paramètre",
|
||||
"zoomIn": "Zoom avant",
|
||||
"zoomOut": "Zoom arrière",
|
||||
"showOptionsPanel": "Montrer la page d'options"
|
||||
}
|
||||
}
|
||||
|
@ -125,7 +125,6 @@
|
||||
"langSimplifiedChinese": "סינית",
|
||||
"langUkranian": "אוקראינית",
|
||||
"langSpanish": "ספרדית",
|
||||
"text2img": "טקסט לתמונה",
|
||||
"img2img": "תמונה לתמונה",
|
||||
"unifiedCanvas": "קנבס מאוחד",
|
||||
"nodes": "צמתים",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Scuro",
|
||||
"lightTheme": "Chiaro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Testo a Immagine",
|
||||
"img2img": "Immagine a Immagine",
|
||||
"unifiedCanvas": "Tela unificata",
|
||||
"nodes": "Nodi",
|
||||
@ -70,7 +69,11 @@
|
||||
"loading": "Caricamento in corso",
|
||||
"oceanTheme": "Oceano",
|
||||
"langHebrew": "Ebraico",
|
||||
"loadingInvokeAI": "Caricamento Invoke AI"
|
||||
"loadingInvokeAI": "Caricamento Invoke AI",
|
||||
"postprocessing": "Post Elaborazione",
|
||||
"txt2img": "Testo a Immagine",
|
||||
"accept": "Accetta",
|
||||
"cancel": "Annulla"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@ -404,7 +407,8 @@
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "niente",
|
||||
"addDifference": "Aggiungi differenza",
|
||||
"pickModelType": "Scegli il tipo di modello"
|
||||
"pickModelType": "Scegli il tipo di modello",
|
||||
"scanForModels": "Cerca modelli"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
||||
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
||||
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
||||
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
|
||||
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
|
||||
"clearCanvasHistory": "Cancella cronologia Tela",
|
||||
"clearHistory": "Cancella la cronologia",
|
||||
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
||||
@ -612,7 +616,7 @@
|
||||
"copyMetadataJson": "Copia i metadati JSON",
|
||||
"exitViewer": "Esci dal visualizzatore",
|
||||
"zoomIn": "Zoom avanti",
|
||||
"zoomOut": "Zoom Indietro",
|
||||
"zoomOut": "Zoom indietro",
|
||||
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
||||
"rotateClockwise": "Ruotare in senso orario",
|
||||
"flipHorizontally": "Capovolgi orizzontalmente",
|
||||
|
@ -11,7 +11,6 @@
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
"text2img": "텍스트->이미지",
|
||||
"unifiedCanvas": "통합 캔버스",
|
||||
"langFrench": "Français",
|
||||
"langGerman": "Deutsch",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Donker",
|
||||
"lightTheme": "Licht",
|
||||
"greenTheme": "Groen",
|
||||
"text2img": "Tekst naar afbeelding",
|
||||
"img2img": "Afbeelding naar afbeelding",
|
||||
"unifiedCanvas": "Centraal canvas",
|
||||
"nodes": "Knooppunten",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Ciemny",
|
||||
"lightTheme": "Jasny",
|
||||
"greenTheme": "Zielony",
|
||||
"text2img": "Tekst na obraz",
|
||||
"img2img": "Obraz na obraz",
|
||||
"unifiedCanvas": "Tryb uniwersalny",
|
||||
"nodes": "Węzły",
|
||||
|
@ -20,7 +20,6 @@
|
||||
"langSpanish": "Espanhol",
|
||||
"langRussian": "Русский",
|
||||
"langUkranian": "Украї́нська",
|
||||
"text2img": "Texto para Imagem",
|
||||
"img2img": "Imagem para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nós",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Noite",
|
||||
"lightTheme": "Dia",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto Para Imagem",
|
||||
"img2img": "Imagem Para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nódulos",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темная",
|
||||
"lightTheme": "Светлая",
|
||||
"greenTheme": "Зеленая",
|
||||
"text2img": "Изображение из текста (text2img)",
|
||||
"img2img": "Изображение в изображение (img2img)",
|
||||
"unifiedCanvas": "Универсальный холст",
|
||||
"nodes": "Ноды",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темна",
|
||||
"lightTheme": "Світла",
|
||||
"greenTheme": "Зелена",
|
||||
"text2img": "Зображення із тексту (text2img)",
|
||||
"img2img": "Зображення із зображення (img2img)",
|
||||
"unifiedCanvas": "Універсальне полотно",
|
||||
"nodes": "Вузли",
|
||||
|
@ -8,7 +8,6 @@
|
||||
"darkTheme": "暗色",
|
||||
"lightTheme": "亮色",
|
||||
"greenTheme": "绿色",
|
||||
"text2img": "文字到图像",
|
||||
"img2img": "图像到图像",
|
||||
"unifiedCanvas": "统一画布",
|
||||
"nodes": "节点",
|
||||
|
@ -33,7 +33,6 @@
|
||||
"langBrPortuguese": "巴西葡萄牙語",
|
||||
"langRussian": "俄語",
|
||||
"langSpanish": "西班牙語",
|
||||
"text2img": "文字到圖像",
|
||||
"unifiedCanvas": "統一畫布"
|
||||
}
|
||||
}
|
||||
|
@ -14,11 +14,11 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { useEffect } from 'react';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
const App = () => {
|
||||
const App = (props: PropsWithChildren) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
@ -40,7 +40,7 @@ const App = () => {
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
<SiteHeader />
|
||||
{props.children || <SiteHeader />}
|
||||
<Flex gap={4} w="full" h="full">
|
||||
<InvokeTabs />
|
||||
<ImageGalleryPanel />
|
||||
|
@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
|
||||
];
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid image heights
|
||||
export const HEIGHTS: Array<number> = [
|
||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
|
||||
1856, 1920, 1984, 2048,
|
||||
];
|
||||
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
|
||||
(_x, i) => i * 64
|
||||
);
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import IAIButton from './IAIButton';
|
||||
|
||||
type Props = {
|
||||
@ -22,10 +23,12 @@ type Props = {
|
||||
};
|
||||
|
||||
const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
acceptButtonText = 'Accept',
|
||||
acceptButtonText = t('common.accept'),
|
||||
acceptCallback,
|
||||
cancelButtonText = 'Cancel',
|
||||
cancelButtonText = t('common.cancel'),
|
||||
cancelCallback,
|
||||
children,
|
||||
title,
|
||||
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
|
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
8
invokeai/frontend/web/src/common/components/IAIForm.tsx
Normal file
@ -0,0 +1,8 @@
|
||||
import { chakra } from '@chakra-ui/react';
|
||||
|
||||
/**
|
||||
* Chakra-enabled <form />
|
||||
*/
|
||||
const IAIForm = chakra.form;
|
||||
|
||||
export default IAIForm;
|
@ -0,0 +1,23 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
|
||||
export function IAIFormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -8,7 +8,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
|
||||
type IAIPopoverProps = PopoverProps & {
|
||||
export type IAIPopoverProps = PopoverProps & {
|
||||
triggerComponent: ReactNode;
|
||||
triggerContainerProps?: BoxProps;
|
||||
children: ReactNode;
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { lazy } from 'react';
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
@ -21,14 +21,14 @@ import './i18n';
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
export default function Component() {
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App />
|
||||
<App>{props.children}</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
16
invokeai/frontend/web/src/exports.tsx
Normal file
16
invokeai/frontend/web/src/exports.tsx
Normal file
@ -0,0 +1,16 @@
|
||||
import Component from './component';
|
||||
|
||||
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
|
||||
import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
InvokeAiLogoComponent,
|
||||
ThemeChanger,
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
};
|
@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<ButtonGroup>
|
||||
<IAIIconButton
|
||||
|
@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
isLazy={false}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.canvasSettings')}
|
||||
|
@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
|
||||
onClick={handleSelectColorPickerTool}
|
||||
/>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('unifiedCanvas.brushOptions')}
|
||||
|
@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
>
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={`${t('parameters.sendTo')}...`}
|
||||
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaGrinStars />}
|
||||
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</IAIPopover>
|
||||
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
icon={<FaExpandArrowsAlt />}
|
||||
|
@ -0,0 +1,24 @@
|
||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
|
||||
type CurrentImageFallbackProps = SpinnerProps;
|
||||
|
||||
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
||||
const { size = 'xl', ...rest } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<Spinner size={size} {...rest} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageFallback;
|
@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
|
||||
src={imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
|
@ -55,7 +55,6 @@ export default function LanguagePicker() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('common.languagePickerLabel')}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||
import type { RootState } from 'app/store';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import type { FieldInputProps, FormikProps } from 'formik';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
|
||||
|
||||
return (
|
||||
<VStack gap={2} alignItems="flex-start">
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Flex columnGap={4}>
|
||||
<IAICheckbox
|
||||
isChecked={!addManually}
|
||||
label={t('modelManager.scanForModels')}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<SearchModels />
|
||||
<IAICheckbox
|
||||
label={t('modelManager.addManually')}
|
||||
isChecked={addManually}
|
||||
onChange={() => setAddmanually(!addManually)}
|
||||
/>
|
||||
|
||||
{addManually && (
|
||||
{addManually ? (
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
|
||||
<VStack rowGap={2}>
|
||||
<Text fontSize={20} fontWeight="bold" alignSelf="start">
|
||||
{t('modelManager.manual')}
|
||||
</Text>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>
|
||||
{errors.description}
|
||||
</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Config */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.config && touched.config}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.config')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="config"
|
||||
name="config"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.config && touched.config ? (
|
||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.configValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Weights */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl
|
||||
isInvalid={!!errors.weights && touched.weights}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="config" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="weights"
|
||||
name="weights"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.weights && touched.weights ? (
|
||||
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* VAE */}
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="2xl"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||
<FormLabel htmlFor="vae" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae"
|
||||
name="vae"
|
||||
type="text"
|
||||
width="full"
|
||||
/>
|
||||
{!!errors.vae && touched.vae ? (
|
||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<HStack width="100%">
|
||||
{/* Width */}
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelManager.width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
width="90%"
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||
<FormLabel htmlFor="width" fontSize="sm">
|
||||
{t('modelManager.width')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="width" name="width">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="width"
|
||||
name="width"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.width}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
{!!errors.width && touched.width ? (
|
||||
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.widthValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
{/* Height */}
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelManager.height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
width="90%"
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<IAIFormItemWrapper>
|
||||
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||
<FormLabel htmlFor="height" fontSize="sm">
|
||||
{t('modelManager.height')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field id="height" name="height">
|
||||
{({
|
||||
field,
|
||||
form,
|
||||
}: {
|
||||
field: FieldInputProps<number>;
|
||||
form: FormikProps<InvokeModelConfigProps>;
|
||||
}) => (
|
||||
<IAINumberInput
|
||||
id="height"
|
||||
name="height"
|
||||
min={MIN_MODEL_SIZE}
|
||||
max={MAX_MODEL_SIZE}
|
||||
step={64}
|
||||
value={form.values.height}
|
||||
onChange={(value) =>
|
||||
form.setFieldValue(field.name, Number(value))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
{!!errors.height && touched.height ? (
|
||||
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.heightValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
</HStack>
|
||||
|
||||
<IAIButton
|
||||
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
) : (
|
||||
<SearchModels />
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
|
@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||
import { addNewModel } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiArrowBack } from 'react-icons/bi';
|
||||
|
||||
import type { RootState } from 'app/store';
|
||||
import type { ReactElement } from 'react';
|
||||
|
||||
function FormItemWrapper({
|
||||
children,
|
||||
}: {
|
||||
children: ReactElement | ReactElement[];
|
||||
}) {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
padding: 4,
|
||||
rowGap: 4,
|
||||
borderRadius: 'base',
|
||||
width: 'full',
|
||||
bg: 'base.900',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
export default function AddDiffusersModel() {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
|
||||
|
||||
return (
|
||||
<Flex>
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
width="max-content"
|
||||
position="absolute"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={3}
|
||||
icon={<BiArrowBack />}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2}>
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{t('modelManager.formMessageDiffusersModelLocation')}
|
||||
</Text>
|
||||
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<FormItemWrapper>
|
||||
<IAIFormItemWrapper>
|
||||
{/* VAE Path */}
|
||||
<Text fontWeight="bold">
|
||||
{t('modelManager.formMessageDiffusersVAELocation')}
|
||||
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</FormItemWrapper>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIButton type="submit" isLoading={isProcessing}>
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import AddCheckpointModel from './AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddDiffusersModel';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
|
||||
function AddModelBox({
|
||||
text,
|
||||
@ -83,8 +84,22 @@ export default function AddModel() {
|
||||
closeOnOverlayClick={false}
|
||||
>
|
||||
<ModalOverlay />
|
||||
<ModalContent margin="auto" paddingInlineEnd={4}>
|
||||
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
|
||||
<ModalContent margin="auto">
|
||||
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
|
||||
{addNewModelUIOption !== null && (
|
||||
<IAIIconButton
|
||||
aria-label={t('common.back')}
|
||||
tooltip={t('common.back')}
|
||||
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||
position="absolute"
|
||||
variant="ghost"
|
||||
zIndex={1}
|
||||
size="sm"
|
||||
insetInlineEnd={12}
|
||||
top={2}
|
||||
icon={<FaArrowLeft />}
|
||||
/>
|
||||
)}
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
{addNewModelUIOption == null && (
|
||||
|
@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
|
||||
import ModelConvert from './ModelConvert';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
|
||||
import { isEqual, pickBy } from 'lodash';
|
||||
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
|
||||
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
|
||||
onSubmit={editModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<IAIForm onSubmit={handleSubmit}>
|
||||
<VStack rowGap={2} alignItems="start">
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
|
||||
{t('modelManager.updateModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
|
@ -12,14 +12,13 @@ import {
|
||||
RadioGroup,
|
||||
Spacer,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { FaPlus, FaSearch } from 'react-icons/fa';
|
||||
import { FaSearch, FaTrash } from 'react-icons/fa';
|
||||
|
||||
import { addNewModel, searchForModels } from 'app/socketio/actions';
|
||||
import {
|
||||
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { forEach, remove } from 'lodash';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
|
||||
const existingModelsSelector = createSelector([systemSelector], (system) => {
|
||||
const { model_list } = system;
|
||||
@ -71,34 +70,32 @@ function SearchModelEntry({
|
||||
};
|
||||
|
||||
return (
|
||||
<VStack>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<IAICheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAICheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
gap={2}
|
||||
backgroundColor={
|
||||
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
|
||||
}
|
||||
paddingX={4}
|
||||
paddingY={2}
|
||||
borderRadius={4}
|
||||
>
|
||||
<Flex gap={4} alignItems="center" justifyContent="space-between">
|
||||
<IAICheckbox
|
||||
value={model.name}
|
||||
label={<Text fontWeight={500}>{model.name}</Text>}
|
||||
isChecked={modelsToAdd.includes(model.name)}
|
||||
isDisabled={existingModels.includes(model.location)}
|
||||
onChange={foundModelsChangeHandler}
|
||||
></IAICheckbox>
|
||||
{existingModels.includes(model.location) && (
|
||||
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
</VStack>
|
||||
<Text fontStyle="italic" variant="subtext">
|
||||
{model.location}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@ -215,10 +212,10 @@ export default function SearchModels() {
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
{newFoundModels}
|
||||
{shouldShowExistingModelsInSearch && existingFoundModels}
|
||||
</>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@ -245,26 +242,26 @@ export default function SearchModels() {
|
||||
<Text
|
||||
sx={{
|
||||
fontWeight: 500,
|
||||
fontSize: 'sm',
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.checkpointFolder')}
|
||||
</Text>
|
||||
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
|
||||
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
|
||||
</Flex>
|
||||
<Spacer />
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<BiReset />}
|
||||
icon={<FaSearch />}
|
||||
fontSize={18}
|
||||
disabled={isProcessing}
|
||||
onClick={() => dispatch(searchForModels(searchFolder))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<FaTrash />}
|
||||
onClick={resetSearchModelHandler}
|
||||
/>
|
||||
</Flex>
|
||||
@ -276,9 +273,9 @@ export default function SearchModels() {
|
||||
}}
|
||||
>
|
||||
{({ handleSubmit }) => (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<HStack columnGap={2} alignItems="flex-end" width="100%">
|
||||
<FormControl isRequired width="lg">
|
||||
<IAIForm onSubmit={handleSubmit} width="100%">
|
||||
<HStack columnGap={2} alignItems="flex-end">
|
||||
<FormControl flexGrow={1}>
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="checkpointFolder"
|
||||
@ -294,12 +291,12 @@ export default function SearchModels() {
|
||||
tooltip={t('modelManager.findModels')}
|
||||
type="submit"
|
||||
disabled={isProcessing}
|
||||
paddingX={10}
|
||||
px={8}
|
||||
>
|
||||
{t('modelManager.findModels')}
|
||||
</IAIButton>
|
||||
</HStack>
|
||||
</form>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
)}
|
||||
@ -410,7 +407,6 @@ export default function SearchModels() {
|
||||
maxHeight={72}
|
||||
overflowY="scroll"
|
||||
borderRadius="sm"
|
||||
paddingInlineEnd={4}
|
||||
gap={2}
|
||||
>
|
||||
{foundModels.length > 0 ? (
|
||||
|
@ -50,7 +50,6 @@ export default function ThemeChanger() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
aria-label={t('common.themeLabel')}
|
||||
|
@ -166,20 +166,8 @@ export default function InvokeTabs() {
|
||||
[]
|
||||
);
|
||||
|
||||
/**
|
||||
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
|
||||
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
|
||||
*
|
||||
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
|
||||
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
|
||||
* the watcher for the hidden drawers, closing the open drawer.
|
||||
*
|
||||
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
|
||||
* tab's drawer.
|
||||
*/
|
||||
return (
|
||||
<Tabs
|
||||
isLazy
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
onChange={(index: number) => {
|
||||
|
@ -93,12 +93,9 @@ const ResizableDrawer = ({
|
||||
useOutsideClick({
|
||||
ref: outsideClickRef,
|
||||
handler: () => {
|
||||
if (isPinned) {
|
||||
return;
|
||||
}
|
||||
|
||||
onClose();
|
||||
},
|
||||
enabled: isOpen && !isPinned,
|
||||
});
|
||||
|
||||
const handleEnables = useMemo(
|
||||
|
@ -77,7 +77,6 @@ export default function UnifiedCanvasColorPicker() {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
triggerComponent={
|
||||
<Box
|
||||
sx={{
|
||||
|
@ -56,7 +56,7 @@ const UnifiedCanvasSettings = () => {
|
||||
|
||||
return (
|
||||
<IAIPopover
|
||||
trigger="hover"
|
||||
isLazy={false}
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.canvasSettings')}
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,4 +1,3 @@
|
||||
import path from 'path';
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { defineConfig, PluginOption } from 'vite';
|
||||
@ -58,26 +57,6 @@ export default defineConfig(({ mode }) => {
|
||||
// sourcemap: true, // this can be enabled if needed, it adds ovwer 15MB to the commit
|
||||
},
|
||||
};
|
||||
} else if (mode === 'package') {
|
||||
return {
|
||||
...common,
|
||||
build: {
|
||||
...common.build,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, 'src/component.tsx'),
|
||||
name: 'InvokeAI UI',
|
||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
||||
},
|
||||
rollupOptions: {
|
||||
external: ['react', 'react-dom'],
|
||||
output: {
|
||||
globals: {
|
||||
react: 'React',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
...common,
|
||||
|
@ -38,16 +38,16 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==0.1.10",
|
||||
"compel==1.0.4",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.14",
|
||||
"dnspython==2.2.1",
|
||||
"einops",
|
||||
"eventlet",
|
||||
"facexlib",
|
||||
"fastapi==0.85.0",
|
||||
"fastapi-events==0.6.0",
|
||||
"fastapi-socketio==0.0.9",
|
||||
"fastapi==0.94.1",
|
||||
"fastapi-events==0.8.0",
|
||||
"fastapi-socketio==0.0.10",
|
||||
"flask==2.1.3",
|
||||
"flask_cors==3.0.10",
|
||||
"flask_socketio==5.3.0",
|
||||
@ -75,7 +75,7 @@ dependencies = [
|
||||
"torchvision>=0.14.1",
|
||||
"torchmetrics",
|
||||
"transformers~=4.26",
|
||||
"uvicorn[standard]==0.20.0",
|
||||
"uvicorn[standard]==0.21.1",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
@ -139,8 +139,24 @@ version = { attr = "invokeai.version.__version__" }
|
||||
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
|
||||
"invokeai.frontend.web.dist" = ["**"]
|
||||
|
||||
#=== Begin: PyTest and Coverage
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-p pytest_cov --junitxml=junit/test-results.xml --cov-report=term:skip-covered --cov=ldm/invoke --cov=backend --cov-branch"
|
||||
addopts = "--cov-report term --cov-report html --cov-report xml"
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source = ["invokeai"]
|
||||
omit = ["*tests*", "*migrations*", ".venv/*", "*.env"]
|
||||
[tool.coverage.report]
|
||||
show_missing = true
|
||||
fail_under = 85 # let's set something sensible on Day 1 ...
|
||||
[tool.coverage.json]
|
||||
output = "coverage/coverage.json"
|
||||
pretty_print = true
|
||||
[tool.coverage.html]
|
||||
directory = "coverage/html"
|
||||
[tool.coverage.xml]
|
||||
output = "coverage/index.xml"
|
||||
#=== End: PyTest and Coverage
|
||||
|
||||
[flake8]
|
||||
max-line-length = 120
|
||||
|
@ -105,17 +105,20 @@
|
||||
|
||||
// Start building nodes
|
||||
var id = 1;
|
||||
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
|
||||
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
|
||||
id++;
|
||||
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
|
||||
id++;
|
||||
var upscaleNode = {"id": id.toString(), "type": "show_image" };
|
||||
id++
|
||||
|
||||
nodes = {};
|
||||
nodes[initialNode.id] = initialNode;
|
||||
nodes[i2iNode.id] = i2iNode;
|
||||
nodes[upscaleNode.id] = upscaleNode;
|
||||
links = [
|
||||
[{ "node_id": initialNode.id, field: "image" },
|
||||
{ "node_id": upscaleNode.id, field: "image" }]
|
||||
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
|
||||
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
|
||||
];
|
||||
// expandSize = 128;
|
||||
// for (var i = 0; i < 6; ++i) {
|
||||
|
@ -1,15 +1,18 @@
|
||||
from invokeai.app.invocations.image import *
|
||||
|
||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||
import pytest
|
||||
|
||||
|
||||
# Helpers
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
return Edge(
|
||||
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||
destination=EdgeConnection(node_id = to_id, field = to_field)
|
||||
)
|
||||
|
||||
# Tests
|
||||
def test_connections_are_compatible():
|
||||
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
|
||||
assert g.get_node("3").prompt == "Banana sushi"
|
||||
|
||||
assert len(g.edges) == 1
|
||||
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
|
||||
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
|
||||
|
||||
def test_graph_fails_to_update_node_id_if_conflict():
|
||||
g = Graph()
|
||||
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
|
||||
assert g2.nodes['1'] is not None
|
||||
assert g2.nodes['2'] is not None
|
||||
assert len(g2.edges) == 1
|
||||
assert g2.edges[0][0].node_id == '1'
|
||||
assert g2.edges[0][0].field == 'image'
|
||||
assert g2.edges[0][1].node_id == '2'
|
||||
assert g2.edges[0][1].field == 'image'
|
||||
assert g2.edges[0].source.node_id == '1'
|
||||
assert g2.edges[0].source.field == 'image'
|
||||
assert g2.edges[0].destination.node_id == '2'
|
||||
assert g2.edges[0].destination.field == 'image'
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
|
@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
||||
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import EdgeConnection
|
||||
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
return Edge(
|
||||
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||
destination=EdgeConnection(node_id = to_id, field = to_field))
|
||||
|
||||
|
||||
class TestEvent:
|
||||
|
Loading…
Reference in New Issue
Block a user