Merge branch 'main' into ip-adapter-style-comp
New binary files under docs/assets/gallery/:

- board_settings.png (23 KiB)
- board_tabs.png (2.7 KiB)
- board_thumbnails.png (30 KiB)
- gallery.png (221 KiB)
- image_menu.png (53 KiB)
- info_button.png (786 B)
- thumbnail_menu.png (27 KiB)
- top_controls.png (3.3 KiB)

docs/features/GALLERY.md (new file, +92 lines)
@@ -0,0 +1,92 @@
---
title: InvokeAI Gallery Panel
---

# :material-web: InvokeAI Gallery Panel

## Quick guided walkthrough of the Gallery Panel's features

The Gallery Panel is a fast way to review, find, and make use of images you've
generated and loaded. The Gallery is divided into Boards. The Uncategorized board is always
present, but you can create your own for better organization.

![image](../assets/gallery/gallery.png)

### Board Display and Settings

At the very top of the Gallery Panel are the boards disclosure and settings buttons.

![image](../assets/gallery/top_controls.png)

The disclosure button shows the name of the currently selected board and allows you to show and hide the board thumbnails (shown in the image below).

![image](../assets/gallery/board_thumbnails.png)

The settings button opens a list of options.

![image](../assets/gallery/board_settings.png)

- ***Image Size*** this slider lets you control the size of the image previews (images of three different sizes).
- ***Auto-Switch to New Images*** if you turn this on, whenever a new image is generated, it will automatically be loaded into the current image panel on the Text to Image tab and into the result panel on the [Image to Image](IMG2IMG.md) tab. This will happen invisibly if you are on any other tab when the image is generated.
- ***Auto-Assign Board on Click*** whenever an image is generated or saved, it always gets put in a board. The board it gets put into is marked with AUTO (image of board marked). Turning on Auto-Assign Board on Click will make whichever board you last selected be the destination when you click Invoke. That means you can click Invoke, select a different board, and then click Invoke again, and the two images will be put in two different boards. **It's the board selected when Invoke is clicked that's used, not the board that's selected when the image finishes generating.** Turning this off enables the Auto-Add Board drop-down, which lets you set one specific board to always put generated images into. This also enables and disables the Auto-add to this Board menu item described below.
- ***Always Show Image Size Badge*** this toggles whether to show image sizes for each image preview (show two images, one with sizes shown, one without).

Below these two buttons, you'll see the Search Boards text entry area. You use this to search for specific boards by name.
Next to it is the Add Board (+) button, which lets you add new boards. Boards can be renamed by clicking on the name of the board under its thumbnail and typing in the new name.

### Board Thumbnail Menu

Each board has a context menu (ctrl+click / right-click).

![image](../assets/gallery/thumbnail_menu.png)

- ***Auto-add to this Board*** if you've disabled Auto-Assign Board on Click in the board settings, you can use this option to make this board the destination for new images.
- ***Download Board*** this will add all the images in the board to a zip file and provide a link to it in a notification (image of notification).
- ***Delete Board*** this will delete the board.

> [!CAUTION]
> This will delete all the images in the board and the board itself.

### Board Contents

Every board is organized into two tabs: Images and Assets.

![image](../assets/gallery/board_tabs.png)

Images are the Invoke-generated images that are placed into the board. Assets are images that you upload into Invoke to be used as an [Image Prompt](https://support.invoke.ai/support/solutions/articles/151000159340-using-the-image-prompt-adapter-ip-adapter-) or in the [Image to Image](IMG2IMG.md) tab.

### Image Thumbnail Menu

Every image generated by Invoke has its generation information stored as text inside the image file itself. This can be read directly by selecting the image and clicking on the Info button ![image](../assets/gallery/info_button.png) in any of the image result panels.

Each image also has a context menu (ctrl+click / right-click).

![image](../assets/gallery/image_menu.png)

The options are (items marked with an * will not work with images that lack generation information):

- ***Open in New Tab*** this will open the image alone in a new browser tab, separate from the Invoke interface.
- ***Download Image*** this will trigger your browser to download the image.
- ***Load Workflow **** this will load any workflow settings into the Workflow tab and automatically open it.
- ***Remix Image **** this will load all of the image's generation information, **excluding its Seed**, into the left-hand control panel.
- ***Use Prompt **** this will load only the image's text prompts into the left-hand control panel.
- ***Use Seed **** this will load only the image's Seed into the left-hand control panel.
- ***Use All **** this will load all of the image's generation information into the left-hand control panel.
- ***Send to Image to Image*** this will put the image into the left-hand panel in the Image to Image tab and automatically open it.
- ***Send to Unified Canvas*** this will **replace whatever is already present** in the Unified Canvas tab with the image and automatically open the tab.
- ***Change Board*** this will open a small window that lets you move the image to a different board. This is the same as dragging the image to that board's thumbnail.
- ***Star Image*** this will add the image to the board's list of starred images, which are always kept at the top of the gallery. This is the same as clicking on the star at the top right-hand side of the image, which appears when you hover over the image with the mouse.
- ***Delete Image*** this will delete the image from the board.

> [!CAUTION]
> This will delete the image entirely from Invoke.

## Summary

This walkthrough only covers the Gallery interface and Boards. Actually generating images is handled by [Prompts](PROMPTS.md), the [Image to Image](IMG2IMG.md) tab, and the [Unified Canvas](UNIFIED_CANVAS.md).

## Acknowledgements

A huge shout-out to the core team working to make the Web GUI a reality,
including [psychedelicious](https://github.com/psychedelicious),
[Kyle0654](https://github.com/Kyle0654) and
[blessedcoolant](https://github.com/blessedcoolant).
[hipsterusername](https://github.com/hipsterusername) was the team's unofficial
cheerleader and added tooltips/docs.

@@ -54,7 +54,7 @@ main sections:
    of buttons at the top lets you modify and manipulate the image in
    various ways.
 
-3. A **gallery** section on the left that contains a history of the images you
+3. A **gallery** section on the right that contains a history of the images you
    have generated. These images are read and written to the directory specified
    in the `INVOKEAIROOT/invokeai.yaml` initialization file, usually a directory
    named `outputs` in `INVOKEAIROOT`.

@@ -23,6 +23,7 @@ If you have an interest in how InvokeAI works, or you would like to add features
 
 1. [Fork and clone] the [InvokeAI repo].
 1. Follow the [manual installation] docs to create a new virtual environment for the development install.
+    - Create a new folder outside the repo root for the installation and create the venv inside that folder.
     - When installing the InvokeAI package, add `-e` to the command so you get an [editable install].
 1. Install the [frontend dev toolchain] and do a production build of the UI as described.
 1. You can now run the app as described in the [manual installation] docs.

@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice
 
 from ..backend.util.logging import InvokeAILogger
 from .api.dependencies import ApiDependencies

@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")
 
-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
 logger.info(f"Using torch device: {torch_device_name}")
 
 
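
These two hunks replace the free function `get_torch_device_name` with a classmethod on the new `TorchDevice` facade that the rest of this diff adopts everywhere. A rough sketch of the pattern (a standalone stand-in, not the real `invokeai.backend.util.devices` implementation, which also honors the app's configured device):

```python
# Sketch of the facade these hunks introduce, assuming a plain torch install.
import torch


class TorchDevice:
    @classmethod
    def choose_torch_device(cls) -> torch.device:
        # Prefer CUDA, then MPS, then fall back to CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @classmethod
    def get_torch_device_name(cls) -> str:
        device = cls.choose_torch_device()
        # Report the GPU product name for CUDA, otherwise the device type.
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()


print(f"Using torch device: {TorchDevice.get_torch_device_name()}")
```
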
@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import CLIPField

@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
         )
 
@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
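
Both compel hunks rely on `dtype_for_device_getter` accepting any callable, so the bound classmethod `TorchDevice.choose_torch_dtype` drops in for the old module-level `torch_dtype` function without touching the surrounding call. A toy illustration of the equivalence (both getters below are stand-ins, not the real implementations):

```python
import torch


def torch_dtype(device: torch.device) -> torch.dtype:
    # Stand-in for the legacy module-level getter: device in, dtype out.
    return torch.float16 if device.type == "cuda" else torch.float32


class TorchDevice:
    @classmethod
    def choose_torch_dtype(cls, device: torch.device) -> torch.dtype:
        # Once bound, the classmethod has the same call shape as the function.
        return torch.float16 if device.type == "cuda" else torch.float32


# Either name can be stored and invoked identically, which is why the diff
# can swap one for the other without changing the call sites.
for getter in (torch_dtype, TorchDevice.choose_torch_dtype):
    assert getter(torch.device("cpu")) is torch.float32
```
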
@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField
 
-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
 
 
 @invocation_output("scheduler_output")
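
The deleted lines show the pattern this branch removes across the codebase: a conditional `from torch import mps` plus per-backend cache clearing at every call site. Note also that `DEFAULT_PRECISION` changes type here, from the precision name string returned by `choose_precision` to a `torch.dtype`. A sketch of how a single `empty_cache` helper can absorb the backend branching (the real method is `TorchDevice.empty_cache`, whose body is not shown in this diff; this assumes a torch build with the `torch.mps` module):

```python
import torch


def empty_cache() -> None:
    # Release cached allocator memory on whichever accelerator is active,
    # so call sites no longer need their own CUDA/MPS branching.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
```
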
@@ -960,9 +957,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@@ -1029,9 +1024,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             vae.disable_tiling()
 
         # clear memory as vae decode can request a lot
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         with torch.inference_mode():
             # copied from diffusers pipeline

@@ -1043,9 +1036,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         image_dto = context.images.save(image=image)
 
@@ -1084,9 +1075,7 @@ class ResizeLatentsInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)
-
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
 
         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),

@@ -1097,9 +1086,8 @@ class ResizeLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

@@ -1126,8 +1114,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
 
         # resizing
         resized_latents = torch.nn.functional.interpolate(

@@ -1139,9 +1126,7 @@ class ScaleLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

@@ -1273,8 +1258,7 @@ class BlendLatentsInvocation(BaseInvocation):
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")
 
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
 
         def slerp(
             t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?

@@ -1327,9 +1311,8 @@ class BlendLatentsInvocation(BaseInvocation):
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)

@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,

@@ -46,7 +46,7 @@ def get_noise(
             height // downsampling_factor,
             width // downsampling_factor,
         ],
-        dtype=torch_dtype(device),
+        dtype=TorchDevice.choose_torch_dtype(device=device),
         device=noise_device_type,
         generator=generator,
     ).to("cpu")

@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
 
     @field_validator("seed", mode="before")
     def modulo_seed(cls, v):
-        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)
 
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
-            device=choose_torch_device(),
+            device=TorchDevice.choose_torch_device(),
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
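
The renamed docstring aside, `modulo_seed` is the piece that keeps arbitrary user seeds inside `[0, SEED_MAX]`. A worked example of the arithmetic, assuming the common unsigned 32-bit bound (the real constant comes from `invokeai.app.util.misc`):

```python
SEED_MAX = 2**32 - 1  # assumed value, for illustration only


def modulo_seed(v: int) -> int:
    # Wrap any integer, including negatives, into [0, SEED_MAX].
    return v % (SEED_MAX + 1)


assert modulo_seed(0) == 0
assert modulo_seed(SEED_MAX + 1) == 0  # one past the max wraps around to zero
assert modulo_seed(-1) == SEED_MAX     # Python's % keeps the result non-negative
```
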
@@ -4,7 +4,6 @@ from typing import Literal
 
 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from pydantic import ConfigDict
 

@@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata

@@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
     "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
 }
 
-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
 
 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
 class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

@@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
         upscaled_image = upscaler.upscale(cv2_image)
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
 
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         image_dto = context.images.save(image=pil_image)
 

@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
 DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"
 
 
 def get_default_ram_cache_size() -> float:

@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`

@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
             # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
             if k == "max_vram_cache_size" and "vram" not in category_dict:
                 parsed_config_dict["vram"] = v
+            # autocast was removed in v4.0.1
+            if k == "precision" and v == "autocast":
+                parsed_config_dict["precision"] = "auto"
             if k == "conf_path":
                 parsed_config_dict["legacy_models_yaml_path"] = v
             if k == "legacy_conf_dir":

@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
     return config
 
 
+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.0 config dictionary to a current config object.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.0 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        # autocast was removed from precision in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
+        else:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
+
+
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.
 
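
The new migration is a straightforward key rewrite: replace the retired `autocast` precision with `auto` and stamp the current schema version before validating. The dictionary transform in isolation looks like this (plain dicts here, where the real function validates the result into an `InvokeAIAppConfig`):

```python
CONFIG_SCHEMA_VERSION = "4.0.1"  # matches the bump earlier in this diff


def migrate_v4_0_0(config_dict: dict) -> dict:
    out: dict = {}
    for k, v in config_dict.items():
        # autocast was removed from precision in v4.0.1
        if k == "precision" and v == "autocast":
            out["precision"] = "auto"
        else:
            out[k] = v
        if k == "schema_version":
            out[k] = CONFIG_SCHEMA_VERSION
    return out


assert migrate_v4_0_0({"schema_version": "4.0.0", "precision": "autocast"}) == {
    "schema_version": "4.0.1",
    "precision": "auto",
}
```
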
@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
             raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
         migrated_config.write_file(config_path)
         return migrated_config
-    else:
-        # Attempt to load as a v4 config file
-        try:
-            # Meta is not included in the model fields, so we need to validate it separately
-            config = InvokeAIAppConfig.model_validate(loaded_config_dict)
-            assert (
-                config.schema_version == CONFIG_SCHEMA_VERSION
-            ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
-            return config
-        except Exception as e:
-            raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
+
+    if loaded_config_dict["schema_version"] == "4.0.0":
+        loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+        loaded_config_dict.write_file(config_path)
+
+    # Attempt to load as a v4 config file
+    try:
+        # Meta is not included in the model fields, so we need to validate it separately
+        config = InvokeAIAppConfig.model_validate(loaded_config_dict)
+        assert (
+            config.schema_version == CONFIG_SCHEMA_VERSION
+        ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
+        return config
+    except Exception as e:
+        raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
 
 
 @lru_cache(maxsize=1)
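
After this rewrite the loader handles versions in sequence rather than in an if/else: v3 files go through the v3 migration, v4.0.0 files are rewritten in place by the new migration, and the result is validated as a current config. A condensed sketch of that control flow, with both migrations stubbed, the version checks simplified, and file I/O omitted:

```python
def migrate_v3(d: dict) -> dict:  # stub for illustration
    return {**d, "schema_version": "4.0.0"}


def migrate_v4_0_0(d: dict) -> dict:  # stub for illustration
    return {**d, "schema_version": "4.0.1"}


def load_and_migrate(loaded: dict) -> dict:
    # Each migration brings the dict one step forward, so an old v3
    # file passes through both before final validation.
    if loaded["schema_version"] == "3.0.0":
        loaded = migrate_v3(loaded)
    if loaded["schema_version"] == "4.0.0":
        loaded = migrate_v4_0_0(loaded)
    assert loaded["schema_version"] == "4.0.1"
    return loaded


print(load_and_migrate({"schema_version": "3.0.0"}))
```
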
@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
 from typing import Any, Dict, List, Optional, Union
 
+import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl

@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 
 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,

@@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase):
             self._next_job_id += 1
         return id
 
-    @staticmethod
-    def _guess_variant() -> Optional[ModelRepoVariant]:
+    def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else None
+        precision = TorchDevice.choose_torch_dtype()
+        return ModelRepoVariant.FP16 if precision == torch.float16 else None
 
     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
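
Two coupled changes land in this hunk: `_guess_variant` becomes an instance method (note the new `self` parameter, which also explains the `import torch` added earlier), and the precision check switches from comparing strings to comparing `torch.dtype` objects, because `TorchDevice.choose_torch_dtype()` returns a dtype rather than the name `choose_precision` used to return. A quick illustration of why the comparison had to change:

```python
import torch

dtype = torch.float16

# Old code compared precision *names* returned by choose_precision():
print(dtype == "float16")      # a torch.dtype does not equal its string name
# New code compares dtype objects directly:
print(dtype == torch.float16)  # True
```
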
@@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
             self._download_cache[download_job.source] = install_job  # matches a download job to an install job
             install_job.download_parts.add(download_job)
 
+        # only start the jobs once install_job.download_parts is fully populated
+        for download_job in install_job.download_parts:
             self._download_queue.submit_download_job(
                 download_job,
                 on_start=self._download_started_callback,

@@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 on_error=self._download_error_callback,
                 on_cancelled=self._download_cancelled_callback,
             )
+
         return install_job
 
     def _stat_size(self, path: Path) -> int:
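
The new loop in the previous two hunks is the actual fix: download submission moves out of the loop that populates `install_job.download_parts` and into a second pass, so a callback firing on the first download can no longer observe a half-built set. A minimal model of the two-phase pattern (the `Job`/`submit` names here are illustrative, not the install service's real API):

```python
from dataclasses import dataclass, field


@dataclass
class Job:
    parts: set = field(default_factory=set)


def install(job: Job, remote_parts: list, submit) -> None:
    # Phase 1: register every part before any work can start.
    for part in remote_parts:
        job.parts.add(part)
    # Phase 2: only start the downloads once job.parts is fully populated,
    # so even an instantly-completing download sees the complete set.
    for part in job.parts:
        submit(part)


install(Job(), ["unet", "vae", "text_encoder"], submit=print)
```
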
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""
 
+from typing import Optional
+
 import torch
 from typing_extensions import Self
 
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
 from ..config import InvokeAIAppConfig

@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_vram_cache_size=app_config.vram,
             lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device,
+            execution_device=execution_device or TorchDevice.choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
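
The signature change two hunks up matters because Python evaluates default argument values once, when the `def` statement runs: `execution_device: torch.device = choose_torch_device()` therefore picked a device at import time. The `Optional` default plus `execution_device or TorchDevice.choose_torch_device()` defers the decision to call time. A sketch of the difference:

```python
from typing import Optional

import torch


def pick() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def eager(device: torch.device = pick()) -> torch.device:
    # pick() ran exactly once, when this def executed; later changes in
    # runtime state are never reflected in the default.
    return device


def lazy(device: Optional[torch.device] = None) -> torch.device:
    # The default is resolved fresh on every call.
    return device or pick()


print(eager(), lazy())
```
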
@@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
 config = get_config()

@@ -56,7 +56,7 @@ class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()
 
     def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]

@@ -81,7 +81,7 @@ class DepthAnythingDetector:
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()
 
-        self.model.to(choose_torch_device())
+        self.model.to(self.device)
         return self.model
 
     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:

@@ -94,7 +94,7 @@ class DepthAnythingDetector:
 
         image_height, image_width = np_image.shape[:2]
         np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
 
         with torch.no_grad():
             depth = self.model(tensor_image)

@@ -7,7 +7,7 @@ import onnxruntime as ort
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 
 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

@@ -28,9 +28,9 @@ config = get_config()
 
 class Wholebody:
     def __init__(self):
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
 
-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
 
         DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
         download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
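
Besides adopting the facade, this hunk tightens a fragile comparison: the old code compared a `torch.device` against the string `"cuda"`, which at best matches only an index-less device. `device.type` is `"cuda"` for every CUDA device, regardless of index, which is what the ONNX provider selection actually cares about:

```python
import torch

for device in (torch.device("cuda"), torch.device("cuda:1"), torch.device("cpu")):
    # device.type ignores the index, so any CUDA device selects the CUDA provider.
    providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
    print(device, providers)
```
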
@@ -8,7 +8,7 @@ from PIL import Image
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 
 
 def norm_img(np_img):

@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):
 
 class LaMA:
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         model_location = get_config().models_path / "core/misc/lama/lama.pt"
 
         if not model_location.exists():

@@ -11,7 +11,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm
 
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 
 """
 Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py

@@ -65,7 +65,7 @@ class RealESRGAN:
         self.pre_pad = pre_pad
         self.mod_scale: Optional[int] = None
         self.half = half
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()
 
         loadnet = torch.load(model_path, map_location=torch.device("cpu"))
 

@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"

@@ -51,7 +51,7 @@ class SafetyChecker:
             cls._load_safety_checker()
         if cls.safety_checker is None or cls.feature_extractor is None:
             return False
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
         cls.safety_checker.to(device)

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice
 
 
 # TO DO: The loader is not thread safe!

@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = TorchDevice.choose_torch_dtype()
 
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """

@@ -30,15 +30,12 @@ import torch
 
 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker
 
-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0

@@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
             )
 
-            torch.cuda.empty_cache()
-            if choose_torch_device() == torch.device("mps"):
-                mps.empty_cache()
+            TorchDevice.empty_cache()
 
     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device.

@@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self.stats.cleared = models_cleared
         gc.collect()
 
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()
 
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
 
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:

@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
 
 from invokeai.app.services.model_install import ModelInstallServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice
 
 from . import (
     AnyModelConfig,

@@ -43,6 +43,7 @@ class ModelMerger(object):
         Initialize a ModelMerger object with the model installer.
         """
         self._installer = installer
+        self._dtype = TorchDevice.choose_torch_dtype()
 
     def merge_diffusion_models(
         self,

@@ -68,7 +69,7 @@ class ModelMerger(object):
             warnings.simplefilter("ignore")
             verbosity = dlogging.get_verbosity()
             dlogging.set_verbosity_error()
-            dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+            dtype = torch.float16 if variant == "fp16" else self._dtype
 
             # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
             # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.

@@ -151,7 +152,7 @@ class ModelMerger(object):
         dump_path.mkdir(parents=True, exist_ok=True)
         dump_path = dump_path / merged_model_name
 
-        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+        dtype = torch.float16 if variant == "fp16" else self._dtype
         merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
 
         # register model and get its unique key

@@ -25,7 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdap
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import normalize_device
+from invokeai.backend.util.devices import TorchDevice
 
 
 @dataclass

@@ -255,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
         elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+            mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
         else:
             raise ValueError(f"unrecognized device {self.unet.device}")
         # input tensor of [1, 4, h/8, w/8]
@@ -2,7 +2,6 @@
 Initialization file for invokeai.backend.util
 """

-from .devices import choose_precision, choose_torch_device
 from .logging import InvokeAILogger
 from .util import GIG, Chdir, directory_size

@@ -11,6 +10,4 @@ __all__ = [
     "directory_size",
     "Chdir",
     "InvokeAILogger",
-    "choose_precision",
-    "choose_torch_device",
 ]
@@ -1,89 +1,110 @@
-from __future__ import annotations
-
-from contextlib import nullcontext
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 import torch
-from torch import autocast
+from deprecated import deprecated

-from invokeai.app.services.config.config_default import PRECISION, get_config
+from invokeai.app.services.config.config_default import get_config

+# legacy APIs
+TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")


+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def choose_precision(device: torch.device) -> TorchPrecisionNames:
+    """Return the string representation of the recommended torch device."""
+    torch_dtype = TorchDevice.choose_torch_dtype(device)
+    return PRECISION_TO_NAME[torch_dtype]
+
+
+@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
 def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
-    config = get_config()
-    if config.device == "auto":
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        else:
-            return CPU_DEVICE
-    else:
-        return torch.device(config.device)
-
-
-def get_torch_device_name() -> str:
-    device = choose_torch_device()
-    return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
-
-
-def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
-    """Return an appropriate precision for the given torch device."""
-    app_config = get_config()
-    if device.type == "cuda":
-        device_name = torch.cuda.get_device_name(device)
-        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
-            # These GPUs have limited support for float16
-            return "float32"
-        elif app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for CUDA devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    elif device.type == "mps":
-        if app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for MPS devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    # CPU / safe fallback
-    return "float32"
-
-
-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
-    device = device or choose_torch_device()
-    precision = choose_precision(device)
-    if precision == "float16":
-        return torch.float16
-    if precision == "bfloat16":
-        return torch.bfloat16
-    else:
-        # "auto", "autocast", "float32"
-        return torch.float32
-
-
-def choose_autocast(precision: PRECISION):
-    """Returns an autocast context or nullcontext for the given precision string"""
-    # float16 currently requires autocast to avoid errors like:
-    # 'expected scalar type Half but found Float'
-    if precision == "autocast" or precision == "float16":
-        return autocast
-    return nullcontext
-
-
-def normalize_device(device: Union[str, torch.device]) -> torch.device:
-    """Ensure device has a device index defined, if appropriate."""
-    device = torch.device(device)
-    if device.index is None:
-        # cuda might be the only torch backend that currently uses the device index?
-        # I don't see anything like `current_device` for cpu or mps.
-        if device.type == "cuda":
-            device = torch.device(device.type, torch.cuda.current_device())
-    return device
+    """Return the torch.device to use for accelerated inference."""
+    return TorchDevice.choose_torch_device()
+
+
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def torch_dtype(device: torch.device) -> torch.dtype:
+    """Return the torch precision for the recommended torch device."""
+    return TorchDevice.choose_torch_dtype(device)
+
+
+NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
+
+
+class TorchDevice:
+    """Abstraction layer for torch devices."""
+
+    @classmethod
+    def choose_torch_device(cls) -> torch.device:
+        """Return the torch.device to use for accelerated inference."""
+        app_config = get_config()
+        if app_config.device != "auto":
+            device = torch.device(app_config.device)
+        elif torch.cuda.is_available():
+            device = CUDA_DEVICE
+        elif torch.backends.mps.is_available():
+            device = MPS_DEVICE
+        else:
+            device = CPU_DEVICE
+        return cls.normalize(device)
+
+    @classmethod
+    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return the precision to use for accelerated inference."""
+        device = device or cls.choose_torch_device()
+        config = get_config()
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device)
+            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+                # These GPUs have limited support for float16
+                return cls._to_dtype("float32")
+            elif config.precision == "auto":
+                # Default to float16 for CUDA devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        elif device.type == "mps" and torch.backends.mps.is_available():
+            if config.precision == "auto":
+                # Default to float16 for MPS devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        # CPU / safe fallback
+        return cls._to_dtype("float32")
+
+    @classmethod
+    def get_torch_device_name(cls) -> str:
+        """Return the device name for the current torch device."""
+        device = cls.choose_torch_device()
+        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
+
+    @classmethod
+    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
+        """Add the device index to CUDA devices."""
+        device = torch.device(device)
+        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
+            device = torch.device(device.type, torch.cuda.current_device())
+        return device
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Clear the GPU device cache."""
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @classmethod
+    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
+        return NAME_TO_PRECISION[precision_name]
@@ -8,7 +8,7 @@
     <meta http-equiv="Pragma" content="no-cache">
     <meta http-equiv="Expires" content="0">
     <title>Invoke - Community Edition</title>
-    <link rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
+    <link id="invoke-favicon" rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
     <style>
       html,
       body {
@@ -1,6 +1,7 @@
 import type { KnipConfig } from 'knip';

 const config: KnipConfig = {
+  project: ['src/**/*.{ts,tsx}!'],
   ignore: [
     // This file is only used during debugging
     'src/app/store/middleware/debugLoggerMiddleware.ts',
@@ -10,6 +11,9 @@ const config: KnipConfig = {
     'src/features/nodes/types/v2/**',
   ],
   ignoreBinaries: ['only-allow'],
+  paths: {
+    'public/*': ['public/*'],
+  },
 };

 export default config;
@@ -24,7 +24,7 @@
     "build": "pnpm run lint && vite build",
     "typegen": "node scripts/typegen.js",
     "preview": "vite preview",
-    "lint:knip": "knip --tags=-@knipignore",
+    "lint:knip": "knip",
     "lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
     "lint:eslint": "eslint --max-warnings=0 .",
     "lint:prettier": "prettier --check .",
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+  <rect width="16" height="16" rx="2" fill="#E6FD13"/>
+  <path d="M9.61889 5.45H12.5V3.5H3.5V5.45H6.38111L9.61889 10.55H12.5V12.5H3.5V10.55H6.38111" stroke="black"/>
+  <circle cx="12" cy="4" r="3" fill="#f5480c" stroke="#0d1117" stroke-width="1"/>
+</svg>
After Width: | Height: | Size: 345 B |
@@ -330,7 +330,8 @@
     "drop": "Drop",
     "dropOrUpload": "$t(gallery.drop) or Upload",
     "dropToUpload": "$t(gallery.drop) to Upload",
-    "deleteImage": "Delete Image",
+    "deleteImage_one": "Delete Image",
+    "deleteImage_other": "Delete {{count}} Images",
     "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
     "deleteImagePermanent": "Deleted images cannot be restored.",
     "download": "Download",
@@ -773,6 +774,8 @@
     "float": "Float",
     "fullyContainNodes": "Fully Contain Nodes to Select",
     "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
+    "showEdgeLabels": "Show Edge Labels",
+    "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
     "hideLegendNodes": "Hide Field Type Legend",
     "hideMinimapnodes": "Hide MiniMap",
     "inputMayOnlyHaveOneConnection": "Input may only have one connection",
@@ -1428,6 +1431,7 @@
     "eraseBoundingBox": "Erase Bounding Box",
     "eraser": "Eraser",
     "fillBoundingBox": "Fill Bounding Box",
+    "hideBoundingBox": "Hide Bounding Box",
     "initialFitImageSize": "Fit Image Size on Drop",
     "invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
     "layer": "Layer",
@@ -1445,6 +1449,7 @@
     "saveMask": "Save $t(unifiedCanvas.mask)",
     "saveToGallery": "Save To Gallery",
     "scaledBoundingBox": "Scaled Bounding Box",
+    "showBoundingBox": "Show Bounding Box",
     "showCanvasDebugInfo": "Show Additional Canvas Info",
     "showGrid": "Show Grid",
     "showResultsOn": "Show Results (On)",
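Note: the `deleteImage_one` / `deleteImage_other` keys above use i18next's plural-suffix convention: passing `count` to `t()` picks the correct form and interpolates `{{count}}`, which is what the `t('gallery.deleteImage', { count: ... })` call sites later in this diff rely on. A minimal standalone sketch (not this app's i18n setup):

import i18next from 'i18next';

i18next
  .init({
    lng: 'en',
    resources: {
      en: {
        translation: {
          deleteImage_one: 'Delete Image',
          deleteImage_other: 'Delete {{count}} Images',
        },
      },
    },
  })
  .then(() => {
    i18next.t('deleteImage', { count: 1 }); // "Delete Image"
    i18next.t('deleteImage', { count: 3 }); // "Delete 3 Images"
  });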
@@ -444,7 +444,8 @@
     "hfTokenInvalidErrorMessage2": "Aggiornalo in ",
     "main": "Principali",
     "noModelsInstalledDesc1": "Installa i modelli con",
-    "ipAdapters": "Adattatori IP"
+    "ipAdapters": "Adattatori IP",
+    "noMatchingModels": "Nessun modello corrispondente"
   },
   "parameters": {
     "images": "Immagini",
@@ -526,7 +527,12 @@
     "aspect": "Aspetto",
     "setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
     "remixImage": "Remixa l'immagine",
-    "coherenceEdgeSize": "Dim. bordo"
+    "coherenceEdgeSize": "Dim. bordo",
+    "infillMosaicTileWidth": "Larghezza piastrella",
+    "infillMosaicMinColor": "Colore minimo",
+    "infillMosaicMaxColor": "Colore massimo",
+    "infillMosaicTileHeight": "Altezza piastrella",
+    "infillColorValue": "Colore di riempimento"
   },
   "settings": {
     "models": "Modelli",
@@ -620,7 +626,8 @@
     "uploadInitialImage": "Carica l'immagine iniziale",
     "problemDownloadingImage": "Impossibile scaricare l'immagine",
     "prunedQueue": "Coda ripulita",
-    "modelImportCanceled": "Importazione del modello annullata"
+    "modelImportCanceled": "Importazione del modello annullata",
+    "parameters": "Parametri"
   },
   "tooltip": {
     "feature": {
@@ -689,7 +696,10 @@
     "coherenceModeBoxBlur": "Sfocatura Box",
     "coherenceModeStaged": "Maschera espansa",
     "invertBrushSizeScrollDirection": "Inverti scorrimento per dimensione pennello",
-    "discardCurrent": "Scarta l'attuale"
+    "discardCurrent": "Scarta l'attuale",
+    "initialFitImageSize": "Adatta dimensione immagine al rilascio",
+    "hideBoundingBox": "Nascondi il rettangolo di selezione",
+    "showBoundingBox": "Mostra il rettangolo di selezione"
   },
   "accessibility": {
     "invokeProgressBar": "Barra di avanzamento generazione",
@@ -832,7 +842,8 @@
     "editMode": "Modifica nell'editor del flusso di lavoro",
     "resetToDefaultValue": "Ripristina il valore predefinito",
     "noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.",
-    "edit": "Modifica"
+    "edit": "Modifica",
+    "graph": "Grafico"
   },
   "boards": {
     "autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1346,13 +1357,13 @@
       ]
     },
     "seamlessTilingXAxis": {
-      "heading": "Asse X di piastrellatura senza cuciture",
+      "heading": "Piastrella senza giunte sull'asse X",
       "paragraphs": [
        "Affianca senza soluzione di continuità un'immagine lungo l'asse orizzontale."
      ]
    },
    "seamlessTilingYAxis": {
-      "heading": "Asse Y di piastrellatura senza cuciture",
+      "heading": "Piastrella senza giunte sull'asse Y",
      "paragraphs": [
        "Affianca senza soluzione di continuità un'immagine lungo l'asse verticale."
      ]
@@ -1476,7 +1487,11 @@
     "name": "Nome",
     "updated": "Aggiornato",
     "projectWorkflows": "Flussi di lavoro del progetto",
-    "opened": "Aperto"
+    "opened": "Aperto",
+    "convertGraph": "Converti grafico",
+    "loadWorkflow": "$t(common.load) Flusso di lavoro",
+    "autoLayout": "Disposizione automatica",
+    "loadFromGraph": "Carica il flusso di lavoro dal grafico"
   },
   "app": {
     "storeNotInitialized": "Il negozio non è inizializzato"
@@ -448,7 +448,9 @@
     "loraModels": "LoRAs",
     "main": "Основные",
     "noModelsInstalled": "Нет установленных моделей",
-    "noModelsInstalledDesc1": "Установите модели с помощью"
+    "noModelsInstalledDesc1": "Установите модели с помощью",
+    "noMatchingModels": "Нет подходящих моделей",
+    "ipAdapters": "IP адаптеры"
   },
   "parameters": {
     "images": "Изображения",
@@ -532,7 +534,12 @@
     "lockAspectRatio": "Заблокировать соотношение",
     "remixImage": "Ремикс изображения",
     "coherenceMinDenoise": "Мин. шумоподавление",
-    "coherenceEdgeSize": "Размер края"
+    "coherenceEdgeSize": "Размер края",
+    "infillMosaicTileWidth": "Ширина плиток",
+    "infillMosaicTileHeight": "Высота плиток",
+    "infillMosaicMinColor": "Мин цвет",
+    "infillMosaicMaxColor": "Макс цвет",
+    "infillColorValue": "Цвет заливки"
   },
   "settings": {
     "models": "Модели",
@@ -626,7 +633,8 @@
     "uploadInitialImage": "Загрузить начальное изображение",
     "resetInitialImage": "Сбросить начальное изображение",
     "prunedQueue": "Урезанная очередь",
-    "modelImportCanceled": "Импорт модели отменен"
+    "modelImportCanceled": "Импорт модели отменен",
+    "parameters": "Параметры"
   },
   "tooltip": {
     "feature": {
@@ -695,7 +703,8 @@
     "coherenceModeGaussianBlur": "Размытие по Гауссу",
     "coherenceModeBoxBlur": "коробчатое размытие",
     "discardCurrent": "Отбросить текущее",
-    "invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти"
+    "invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти",
+    "initialFitImageSize": "Подогнать размер изображения при перебросе"
   },
   "accessibility": {
     "uploadImage": "Загрузить изображение",
@@ -921,7 +930,8 @@
     "modelSize": "Размер модели",
     "small": "Маленький",
     "body": "Тело",
-    "hands": "Руки"
+    "hands": "Руки",
+    "selectCLIPVisionModel": "Выбрать модель CLIP Vision"
   },
   "boards": {
     "autoAddBoard": "Авто добавление Доски",
@@ -65,7 +65,12 @@
     "nextPage": "下一页",
     "saveAs": "保存为",
     "ai": "ai",
-    "or": "或"
+    "or": "或",
+    "aboutDesc": "使用 Invoke 工作?查看:",
+    "add": "添加",
+    "loglevel": "日志级别",
+    "copy": "复制",
+    "localSystem": "本地系统"
   },
   "gallery": {
     "galleryImageSize": "预览大小",
@@ -599,7 +604,8 @@
     "loadMore": "加载更多",
     "mode": "模式",
     "resetUI": "$t(accessibility.reset) UI",
-    "createIssue": "创建问题"
+    "createIssue": "创建问题",
+    "about": "关于"
   },
   "tooltip": {
     "feature": {
@@ -1201,7 +1207,16 @@
     "workflows": "工作流",
     "noDescription": "无描述",
     "uploadWorkflow": "从文件中加载",
-    "newWorkflowCreated": "已创建新的工作流"
+    "newWorkflowCreated": "已创建新的工作流",
+    "name": "名称",
+    "defaultWorkflows": "默认工作流",
+    "created": "已创建",
+    "ascending": "升序",
+    "descending": "降序",
+    "updated": "已更新",
+    "userWorkflows": "我的工作流",
+    "projectWorkflows": "项目工作流",
+    "opened": "已打开"
   },
   "app": {
     "storeNotInitialized": "商店尚未初始化"
@@ -1219,7 +1234,8 @@
     "title": "生成"
   },
   "advanced": {
-    "title": "高级"
+    "title": "高级",
+    "options": "$t(accordions.advanced.title) 选项"
   },
   "image": {
     "title": "图像"
@@ -1,5 +1,6 @@
 import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
 import { useSocketIO } from 'app/hooks/useSocketIO';
+import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
 import { useLogger } from 'app/logging/useLogger';
 import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -70,6 +71,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
   }, [dispatch]);

   useStarterModelsToast();
+  useSyncQueueStatus();

   return (
     <ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
25
invokeai/frontend/web/src/app/hooks/useSyncQueueStatus.ts
Normal file
@@ -0,0 +1,25 @@
+import { useEffect } from 'react';
+import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
+
+const baseTitle = document.title;
+const invokeLogoSVG = 'assets/images/invoke-favicon.svg';
+const invokeAlertLogoSVG = 'assets/images/invoke-alert-favicon.svg';
+
+/**
+ * This hook synchronizes the queue status with the page's title and favicon.
+ * It should be considered a singleton and only used once in the component tree.
+ */
+export const useSyncQueueStatus = () => {
+  const { queueSize } = useGetQueueStatusQuery(undefined, {
+    selectFromResult: (res) => ({
+      queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
+    }),
+  });
+  useEffect(() => {
+    document.title = queueSize > 0 ? `(${queueSize}) ${baseTitle}` : baseTitle;
+    const faviconEl = document.getElementById('invoke-favicon');
+    if (faviconEl instanceof HTMLLinkElement) {
+      faviconEl.href = queueSize > 0 ? invokeAlertLogoSVG : invokeLogoSVG;
+    }
+  }, [queueSize]);
+};
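Note: the hook above is a singleton by design (mounted once in App.tsx) because it writes to global DOM state. Stripped of React, the DOM side reduces to the sketch below; `syncQueueBadge` is an illustrative name, not part of this PR:

const pageTitle = document.title;

function syncQueueBadge(queueSize: number): void {
  // Prefix the tab title with the pending count, e.g. "(3) Invoke - Community Edition".
  document.title = queueSize > 0 ? `(${queueSize}) ${pageTitle}` : pageTitle;
  // Swap the favicon via the <link id="invoke-favicon"> element added in index.html above.
  const faviconEl = document.getElementById('invoke-favicon');
  if (faviconEl instanceof HTMLLinkElement) {
    faviconEl.href = queueSize > 0 ? 'assets/images/invoke-alert-favicon.svg' : 'assets/images/invoke-favicon.svg';
  }
}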
@@ -1,5 +1,4 @@
 import { Flex, Image, Spinner } from '@invoke-ai/ui-library';
-/** @knipignore */
 import InvokeLogoWhite from 'public/assets/images/invoke-symbol-wht-lrg.svg';
 import { memo } from 'react';

@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';

 export const useGlobalHotkeys = () => {
   const dispatch = useAppDispatch();
-  const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
+  const isModelManagerEnabled = useFeatureStatus('modelManager');
   const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();

   useHotkeys(
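Note: this and several later hunks change `useFeatureStatus('x').isFeatureEnabled` to plain `useFeatureStatus('x')`, which implies the hook now returns the boolean directly. A plausible sketch of the refactored hook; the selector and config shape below are assumptions, not shown in this diff:

import { useMemo } from 'react';
import { useAppSelector } from 'app/store/storeHooks';

// Hypothetical shape: the real hook's feature-id type and store layout may differ.
export const useFeatureStatus = (feature: string): boolean => {
  const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
  return useMemo(() => !disabledFeatures.includes(feature), [disabledFeatures, feature]);
};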
@@ -13,7 +13,13 @@ import {
 } from 'features/canvas/store/actions';
 import { $canvasBaseLayer, $tool } from 'features/canvas/store/canvasNanostore';
 import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
-import { resetCanvas, resetCanvasView, setIsMaskEnabled, setLayer } from 'features/canvas/store/canvasSlice';
+import {
+  resetCanvas,
+  resetCanvasView,
+  setIsMaskEnabled,
+  setLayer,
+  setShouldShowBoundingBox,
+} from 'features/canvas/store/canvasSlice';
 import type { CanvasLayer } from 'features/canvas/store/canvasTypes';
 import { LAYER_NAMES_DICT } from 'features/canvas/store/canvasTypes';
 import { memo, useCallback, useMemo } from 'react';
@@ -23,6 +29,8 @@ import {
   PiCopyBold,
   PiCrosshairSimpleBold,
   PiDownloadSimpleBold,
+  PiEyeBold,
+  PiEyeSlashBold,
   PiFloppyDiskBold,
   PiHandGrabbingBold,
   PiStackBold,
@@ -44,6 +52,7 @@ const IAICanvasToolbar = () => {
   const isStaging = useAppSelector(isStagingSelector);
   const { t } = useTranslation();
   const { isClipboardAPIAvailable } = useCopyImageToClipboard();
+  const shouldShowBoundingBox = useAppSelector((s) => s.canvas.shouldShowBoundingBox);

   const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
     postUploadAction: { type: 'SET_CANVAS_INITIAL_IMAGE' },
@@ -61,6 +70,18 @@ const IAICanvasToolbar = () => {
     []
   );

+  useHotkeys(
+    'shift+h',
+    () => {
+      dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
+    },
+    {
+      enabled: () => !isStaging,
+      preventDefault: true,
+    },
+    [shouldShowBoundingBox]
+  );
+
   useHotkeys(
     ['r'],
     () => {
@@ -125,6 +146,10 @@ const IAICanvasToolbar = () => {
     $tool.set('move');
   }, []);

+  const handleSetShouldShowBoundingBox = useCallback(() => {
+    dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
+  }, [dispatch, shouldShowBoundingBox]);
+
   const handleResetCanvasView = useCallback(
     (shouldScaleTo1 = false) => {
       const canvasBaseLayer = $canvasBaseLayer.get();
@@ -212,6 +237,13 @@ const IAICanvasToolbar = () => {
         isChecked={tool === 'move' || isStaging}
         onClick={handleSelectMoveTool}
       />
+      <IconButton
+        aria-label={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
+        tooltip={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
+        icon={shouldShowBoundingBox ? <PiEyeBold /> : <PiEyeSlashBold />}
+        onClick={handleSetShouldShowBoundingBox}
+        isDisabled={isStaging}
+      />
       <IconButton
         aria-label={`${t('unifiedCanvas.resetView')} (R)`}
         tooltip={`${t('unifiedCanvas.resetView')} (R)`}
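Note: react-hotkeys-hook's `useHotkeys(keys, callback, options, deps)` accepts both an options object and a trailing dependency array, as the Shift+H binding above uses. The `enabled: () => !isStaging` option keeps the shortcut inert while canvas staging is active, and `[shouldShowBoundingBox]` re-binds the callback so the toggle always reads the current value. A minimal sketch of the same shape:

useHotkeys(
  'shift+h',
  () => dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox)),
  { enabled: () => !isStaging, preventDefault: true },
  [shouldShowBoundingBox]
);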
@@ -7,12 +7,7 @@ import {
   resetToolInteractionState,
 } from 'features/canvas/store/canvasNanostore';
 import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
-import {
-  clearMask,
-  setIsMaskEnabled,
-  setShouldShowBoundingBox,
-  setShouldSnapToGrid,
-} from 'features/canvas/store/canvasSlice';
+import { clearMask, setIsMaskEnabled, setShouldSnapToGrid } from 'features/canvas/store/canvasSlice';
 import { isInteractiveTarget } from 'features/canvas/util/isInteractiveTarget';
 import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
 import { useCallback, useEffect } from 'react';
@@ -21,7 +16,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
 const useInpaintingCanvasHotkeys = () => {
   const dispatch = useAppDispatch();
   const activeTabName = useAppSelector(activeTabNameSelector);
-  const shouldShowBoundingBox = useAppSelector((s) => s.canvas.shouldShowBoundingBox);
   const isStaging = useAppSelector(isStagingSelector);
   const isMaskEnabled = useAppSelector((s) => s.canvas.isMaskEnabled);
   const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
@@ -79,18 +73,6 @@ const useInpaintingCanvasHotkeys = () => {
     }
   );

-  useHotkeys(
-    'shift+h',
-    () => {
-      dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
-    },
-    {
-      enabled: () => !isStaging,
-      preventDefault: true,
-    },
-    [activeTabName, shouldShowBoundingBox]
-  );
-
   const onKeyDown = useCallback(
     (e: KeyboardEvent) => {
       if (e.repeat || e.key !== ' ' || isInteractiveTarget(e.target) || activeTabName !== 'unifiedCanvas') {
@@ -103,7 +103,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {

   return (
     <Flex sx={{ gap: 2 }}>
-      <Tooltip label={value?.description}>
+      <Tooltip label={selectedModel?.description}>
        <FormControl
          isDisabled={!isEnabled}
          isInvalid={!value || mainModel?.base !== modelConfig?.base}
@@ -13,13 +13,15 @@ export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
   const { onClick, isDisabled } = props;
   const { t } = useTranslation();
   const isConnected = useAppSelector((s) => s.system.isConnected);
+  const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length);
+  const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`;

   return (
     <IconButton
       onClick={onClick}
       icon={<PiTrashSimpleBold />}
-      tooltip={`${t('gallery.deleteImage')} (Del)`}
-      aria-label={`${t('gallery.deleteImage')} (Del)`}
+      tooltip={labelMessage}
+      aria-label={labelMessage}
       isDisabled={isDisabled || !isConnected}
       colorScheme="error"
     />
@@ -80,7 +80,7 @@ const DeleteImageModal = () => {

   return (
     <ConfirmationAlertDialog
-      title={t('gallery.deleteImage')}
+      title={t('gallery.deleteImage', { count: imagesToDelete.length })}
       isOpen={isModalOpen}
       onClose={handleClose}
       cancelButtonText={t('boards.cancel')}
@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props

   const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
   const boardName = useBoardName(board_id);
-  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

   const [bulkDownload] = useBulkDownloadImagesMutation();

@@ -6,7 +6,6 @@ import type { RemoveFromBoardDropData } from 'features/dnd/types';
 import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
 import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
 import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
-/** @knipignore */
 import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
 import { memo, useCallback, useMemo, useState } from 'react';
 import { useTranslation } from 'react-i18next';
@@ -51,9 +51,10 @@ const CurrentImageButtons = () => {
   const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
   const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
   const lastSelectedImage = useAppSelector(selectLastSelectedImage);
+  const selection = useAppSelector((s) => s.gallery.selection);
   const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);

-  const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
+  const isUpscalingEnabled = useFeatureStatus('upscaling');
   const isQueueMutationInProgress = useIsQueueMutationInProgress();
   const toaster = useAppToaster();
   const { t } = useTranslation();
@@ -102,8 +103,8 @@ const CurrentImageButtons = () => {
     if (!imageDTO) {
       return;
     }
-    dispatch(imagesToDeleteSelected([imageDTO]));
-  }, [dispatch, imageDTO]);
+    dispatch(imagesToDeleteSelected(selection));
+  }, [dispatch, imageDTO, selection]);

   useHotkeys(
     'Shift+U',
@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
   const selection = useAppSelector((s) => s.gallery.selection);
   const customStarUi = useStore($customStarUI);

-  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

   const [starImages] = useStarImagesMutation();
   const [unstarImages] = useUnstarImagesMutation();
|
@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
const { downloadImage } = useDownloadImage();
|
const { downloadImage } = useDownloadImage();
|
||||||
|
|
||||||
@ -188,7 +188,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
)}
|
)}
|
||||||
<MenuDivider />
|
<MenuDivider />
|
||||||
<MenuItem color="error.300" icon={<PiTrashSimpleBold />} onClickCapture={handleDelete}>
|
<MenuItem color="error.300" icon={<PiTrashSimpleBold />} onClickCapture={handleDelete}>
|
||||||
{t('gallery.deleteImage')}
|
{t('gallery.deleteImage', { count: 1 })}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
@@ -180,7 +180,7 @@ const GalleryImage = (props: HoverableImageProps) => {
           <IAIDndImageIcon
             onClick={handleDelete}
             icon={<PiTrashSimpleFill size="16px" />}
-            tooltip={t('gallery.deleteImage')}
+            tooltip={t('gallery.deleteImage', { count: 1 })}
             styleOverrides={imageIconStyleOverrides}
           />
         )}
@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
     [imageDTO?.image_name]
   );
   const isSelected = useAppSelector(selectIsSelected);
-  const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
+  const isMultiSelectEnabled = useFeatureStatus('multiselect');

   const handleClick = useCallback(
     (e: MouseEvent<HTMLDivElement>) => {
@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
 import ParamHrfToggle from './ParamHrfToggle';

 export const HrfSettings = memo(() => {
-  const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
+  const isHRFFeatureEnabled = useFeatureStatus('hrf');
   const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);

   if (!isHRFFeatureEnabled) {
@@ -156,8 +156,13 @@ const parseSteps: MetadataParseFunc<ParameterSteps> = (metadata) => getProperty(
 const parseStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
   getProperty(metadata, 'strength', isParameterStrength);

-const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = (metadata) =>
-  getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
+const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = async (metadata) => {
+  try {
+    return await getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
+  } catch {
+    return false;
+  }
+};

 const parseHRFStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
   getProperty(metadata, 'hrf_strength', isParameterStrength);
@@ -224,12 +229,16 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
 };

 const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
-  const lorasRaw = await getProperty(metadata, 'loras', isArray);
-  const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
-  const loras = parseResults
-    .filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
-    .map((result) => result.value);
-  return loras;
+  try {
+    const lorasRaw = await getProperty(metadata, 'loras', isArray);
+    const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
+    const loras = parseResults
+      .filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
+      .map((result) => result.value);
+    return loras;
+  } catch {
+    return [];
+  }
 };

 const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (metadataItem) => {
@@ -288,12 +297,16 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
 };

 const parseAllControlNets: MetadataParseFunc<ControlNetConfigMetadata[]> = async (metadata) => {
-  const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
-  const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
-  const controlNets = parseResults
-    .filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
-    .map((result) => result.value);
-  return controlNets;
+  try {
+    const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray || undefined);
+    const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
+    const controlNets = parseResults
+      .filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
+      .map((result) => result.value);
+    return controlNets;
+  } catch {
+    return [];
+  }
 };

 const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (metadataItem) => {
@@ -348,12 +361,16 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
 };

 const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfigMetadata[]> = async (metadata) => {
-  const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
-  const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
-  const t2iAdapters = parseResults
-    .filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
-    .map((result) => result.value);
-  return t2iAdapters;
+  try {
+    const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
+    const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
+    const t2iAdapters = parseResults
+      .filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
+      .map((result) => result.value);
+    return t2iAdapters;
+  } catch {
+    return [];
+  }
 };

 const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metadataItem) => {
@@ -399,12 +416,16 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
 };

 const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (metadata) => {
-  const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
-  const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
-  const ipAdapters = parseResults
-    .filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
-    .map((result) => result.value);
-  return ipAdapters;
+  try {
+    const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
+    const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
+    const ipAdapters = parseResults
+      .filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
+      .map((result) => result.value);
+    return ipAdapters;
+  } catch {
+    return [];
+  }
 };

 export const parsers = {
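Note: the four `parseAll*` changes above repeat one pattern: treat a missing or malformed metadata property as "nothing to recall" instead of letting the rejection propagate. A generic helper (illustrative only, not part of this PR) that captures the same behavior:

const withFallback = async <T>(parse: () => Promise<T>, fallback: T): Promise<T> => {
  try {
    return await parse();
  } catch {
    // Missing or invalid property: recall nothing rather than failing the whole parse.
    return fallback;
  }
};

// Usage sketch: const loras = await withFallback(() => getProperty(metadata, 'loras', isArray), []);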
@@ -177,11 +177,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
 };

 const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
+  const { dispatch } = getStore();
+  dispatch(lorasReset());
   if (!loras.length) {
     return;
   }
-  const { dispatch } = getStore();
-  dispatch(lorasReset());
   loras.forEach((lora) => {
     dispatch(loraRecalled(lora));
   });
@@ -192,11 +192,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
 };

 const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
+  const { dispatch } = getStore();
+  dispatch(controlNetsReset());
   if (!controlNets.length) {
     return;
   }
-  const { dispatch } = getStore();
-  dispatch(controlNetsReset());
   controlNets.forEach((controlNet) => {
     dispatch(controlAdapterRecalled(controlNet));
   });
@@ -207,11 +207,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
 };

 const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
+  const { dispatch } = getStore();
+  dispatch(t2iAdaptersReset());
   if (!t2iAdapters.length) {
     return;
   }
-  const { dispatch } = getStore();
-  dispatch(t2iAdaptersReset());
   t2iAdapters.forEach((t2iAdapter) => {
     dispatch(controlAdapterRecalled(t2iAdapter));
   });
@@ -222,11 +222,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
 };

 const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
+  const { dispatch } = getStore();
+  dispatch(ipAdaptersReset());
   if (!ipAdapters.length) {
     return;
   }
-  const { dispatch } = getStore();
-  dispatch(ipAdaptersReset());
   ipAdapters.forEach((ipAdapter) => {
     dispatch(controlAdapterRecalled(ipAdapter));
   });
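Note: the reordering in these four recall functions is behavioral, not cosmetic. Resetting before the `if (!items.length) return;` guard means recalling metadata with an empty list now clears any previously recalled LoRAs or adapters, whereas the old order left stale state in place. The shared shape, as a generic sketch (not code from this PR):

const recallAll = <T>(items: T[], reset: () => void, recall: (item: T) => void): void => {
  reset(); // clear stale state first, even when there is nothing new to recall
  if (!items.length) {
    return;
  }
  items.forEach(recall);
};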
@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';

 export const useStarterModelsToast = () => {
   const { t } = useTranslation();
-  const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
+  const isEnabled = useFeatureStatus('starterModels');
   const [didToast, setDidToast] = useState(false);
   const [mainModels, { data }] = useMainModels();
   const toast = useToast();
@ -1,8 +1,9 @@
|
|||||||
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { CSSProperties } from 'react';
|
import type { CSSProperties } from 'react';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import type { EdgeProps } from 'reactflow';
|
import type { EdgeProps } from 'reactflow';
|
||||||
import { BaseEdge, getBezierPath } from 'reactflow';
|
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||||
|
|
||||||
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
||||||
|
|
||||||
@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
|
|||||||
[source, sourceHandleId, target, targetHandleId, selected]
|
[source, sourceHandleId, target, targetHandleId, selected]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||||
|
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
||||||
|
|
||||||
const [edgePath] = getBezierPath({
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
sourceX,
|
sourceX,
|
||||||
sourceY,
|
sourceY,
|
||||||
sourcePosition,
|
sourcePosition,
|
||||||
@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
|
|||||||
[isSelected, shouldAnimate, stroke]
|
[isSelected, shouldAnimate, stroke]
|
||||||
);
|
);
|
||||||
|
|
||||||
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
|
return (
|
||||||
|
<>
|
||||||
|
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
|
||||||
|
{label && shouldShowEdgeLabels && (
|
||||||
|
<EdgeLabelRenderer>
|
||||||
|
<Flex
|
||||||
|
className="nodrag nopan"
|
||||||
|
pointerEvents="all"
|
||||||
|
position="absolute"
|
||||||
|
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
|
||||||
|
bg="base.800"
|
||||||
|
borderRadius="base"
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor={isSelected ? 'undefined' : 'transparent'}
|
||||||
|
opacity={isSelected ? 1 : 0.5}
|
||||||
|
py={1}
|
||||||
|
px={3}
|
||||||
|
shadow="md"
|
||||||
|
>
|
||||||
|
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</EdgeLabelRenderer>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(InvocationDefaultEdge);
|
export default memo(InvocationDefaultEdge);
|
||||||
|
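For readers unfamiliar with the reactflow APIs introduced above, a minimal sketch (assuming reactflow v11, which the imports suggest): `getBezierPath` returns the label midpoint along with the SVG path, and `EdgeLabelRenderer` portals its children into an HTML layer, so the label must be positioned onto that midpoint with a CSS transform.

```tsx
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';

export const SketchEdge = (props: { sourceX: number; sourceY: number; targetX: number; targetY: number }) => {
  // getBezierPath returns [svgPath, labelX, labelY, ...]
  const [path, labelX, labelY] = getBezierPath(props);
  return (
    <>
      <BaseEdge path={path} />
      <EdgeLabelRenderer>
        {/* Rendered into a plain HTML (non-SVG) layer, hence absolute positioning. */}
        <div style={{ position: 'absolute', transform: `translate(-50%, -50%) translate(${labelX}px,${labelY}px)` }}>
          edge label
        </div>
      </EdgeLabelRenderer>
    </>
  );
};
```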
@@ -1,7 +1,7 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
 import { isInvocationNode } from 'features/nodes/types/invocation';
 
 import { getFieldColor } from './getEdgeColor';
@@ -10,6 +10,7 @@ const defaultReturnValue = {
   isSelected: false,
   shouldAnimate: false,
   stroke: colorTokenToCssVar('base.500'),
+  label: '',
 };
 
 export const makeEdgeSelector = (
@@ -19,25 +20,34 @@ export const makeEdgeSelector = (
   targetHandleId: string | null | undefined,
   selected?: boolean
 ) =>
-  createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
-    const sourceNode = nodes.nodes.find((node) => node.id === source);
-    const targetNode = nodes.nodes.find((node) => node.id === target);
-
-    const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
-
-    const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
-    if (!sourceNode || !sourceHandleId) {
-      return defaultReturnValue;
-    }
-
-    const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
-    const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
-
-    const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
-
-    return {
-      isSelected,
-      shouldAnimate: nodes.shouldAnimateEdges && isSelected,
-      stroke,
-    };
-  });
+  createMemoizedSelector(
+    selectNodesSlice,
+    (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
+      const sourceNode = nodes.nodes.find((node) => node.id === source);
+      const targetNode = nodes.nodes.find((node) => node.id === target);
+
+      const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
+
+      const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
+      if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
+        return defaultReturnValue;
+      }
+
+      const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
+      const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
+
+      const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
+
+      const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
+      const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
+
+      const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
+
+      return {
+        isSelected,
+        shouldAnimate: nodes.shouldAnimateEdges && isSelected,
+        stroke,
+        label,
+      };
    }
+  );
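A small sketch of the label derivation added above: each side prefers the node template's title and falls back to the node's user-assigned label (the values below are hypothetical):

```ts
// Hypothetical values for illustration.
const sourceTemplateTitle: string | undefined = undefined; // no template found
const sourceUserLabel = 'My Custom Node'; // user-assigned label
const targetTemplateTitle: string | undefined = 'Denoise Latents';
const targetUserLabel = '';

export const label = `${sourceTemplateTitle || sourceUserLabel} -> ${targetTemplateTitle || targetUserLabel}`;
// => 'My Custom Node -> Denoise Latents'
```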
@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
 
 const InvocationNodeFooter = ({ nodeId }: Props) => {
   const hasImageOutput = useHasImageOutput(nodeId);
-  const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+  const isCacheEnabled = useFeatureStatus('invocationCache');
   return (
     <Flex
       className={DRAG_HANDLE_CLASSNAME}
@@ -24,6 +24,7 @@ import {
   selectNodesSlice,
   shouldAnimateEdgesChanged,
   shouldColorEdgesChanged,
+  shouldShowEdgeLabelsChanged,
   shouldSnapToGridChanged,
   shouldValidateGraphChanged,
 } from 'features/nodes/store/nodesSlice';
@@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
 const formLabelProps: FormLabelProps = { flexGrow: 1 };
 
 const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
-  const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
+  const {
+    shouldAnimateEdges,
+    shouldValidateGraph,
+    shouldSnapToGrid,
+    shouldColorEdges,
+    shouldShowEdgeLabels,
+    selectionMode,
+  } = nodes;
   return {
     shouldAnimateEdges,
     shouldValidateGraph,
     shouldSnapToGrid,
     shouldColorEdges,
+    shouldShowEdgeLabels,
     selectionModeIsChecked: selectionMode === SelectionMode.Full,
   };
 });
@@ -52,8 +61,14 @@ type Props = {
 const WorkflowEditorSettings = ({ children }: Props) => {
   const { isOpen, onOpen, onClose } = useDisclosure();
   const dispatch = useAppDispatch();
-  const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
-    useAppSelector(selector);
+  const {
+    shouldAnimateEdges,
+    shouldValidateGraph,
+    shouldSnapToGrid,
+    shouldColorEdges,
+    shouldShowEdgeLabels,
+    selectionModeIsChecked,
+  } = useAppSelector(selector);
 
   const handleChangeShouldValidate = useCallback(
     (e: ChangeEvent<HTMLInputElement>) => {
@@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
     [dispatch]
   );
 
+  const handleChangeShouldShowEdgeLabels = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
+    },
+    [dispatch]
+  );
+
   const { t } = useTranslation();
 
   return (
@@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
           <FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
         </FormControl>
         <Divider />
+        <FormControl>
+          <Flex w="full">
+            <FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
+            <Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
+          </Flex>
+          <FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
+        </FormControl>
+        <Divider />
         <Heading size="sm" pt={4}>
           {t('common.advanced')}
         </Heading>
@@ -1,7 +1,6 @@
 import { Button, Flex, Image, Text } from '@invoke-ai/ui-library';
 import { useAppDispatch } from 'app/store/storeHooks';
 import { workflowModeChanged } from 'features/nodes/store/workflowSlice';
-/** @knipignore */
 import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
@@ -1,5 +1,5 @@
-import { createSelector } from '@reduxjs/toolkit';
 import { EMPTY_ARRAY } from 'app/store/constants';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppSelector } from 'app/store/storeHooks';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
 import { selectNodeTemplate } from 'features/nodes/store/selectors';
@@ -10,7 +10,7 @@ import { useMemo } from 'react';
 export const useOutputFieldNames = (nodeId: string) => {
   const selector = useMemo(
     () =>
-      createSelector(selectNodesSlice, (nodes) => {
+      createMemoizedSelector(selectNodesSlice, (nodes) => {
         const template = selectNodeTemplate(nodes, nodeId);
         if (!template) {
           return EMPTY_ARRAY;
@@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
 
 export const useWithFooter = (nodeId: string) => {
   const hasImageOutput = useHasImageOutput(nodeId);
-  const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+  const isCacheEnabled = useFeatureStatus('invocationCache');
 
   const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
   return withFooter;
 };
@@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
   shouldAnimateEdges: true,
   shouldSnapToGrid: false,
   shouldColorEdges: true,
+  shouldShowEdgeLabels: false,
   isAddNodePopoverOpen: false,
   nodeOpacity: 1,
   selectedNodes: [],
@@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
     shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
       state.shouldAnimateEdges = action.payload;
     },
+    shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
+      state.shouldShowEdgeLabels = action.payload;
+    },
     shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
       state.shouldSnapToGrid = action.payload;
     },
@@ -831,6 +835,7 @@ export const {
   viewportChanged,
   edgeAdded,
   nodeTemplatesBuilt,
+  shouldShowEdgeLabelsChanged,
 } = nodesSlice.actions;
 
 // This is used for tracking `state.workflow.isTouched`
@@ -32,6 +32,7 @@ export type NodesState = {
   isAddNodePopoverOpen: boolean;
   addNewNodePosition: XYPosition | null;
   selectionMode: SelectionMode;
+  shouldShowEdgeLabels: boolean;
 };
 
 export type WorkflowMode = 'edit' | 'view';
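Taken together, the nodesSlice hunks follow the standard redux-toolkit recipe for a new boolean setting: a default in the initial state, a reducer, an exported action creator, and a field on the state type. A self-contained sketch of the pattern:

```ts
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';

const sketchSlice = createSlice({
  name: 'sketch',
  initialState: { shouldShowEdgeLabels: false },
  reducers: {
    // Immer allows direct "mutation" of the draft state inside reducers.
    shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
      state.shouldShowEdgeLabels = action.payload;
    },
  },
});

export const { shouldShowEdgeLabelsChanged } = sketchSlice.actions;
// A component then dispatches shouldShowEdgeLabelsChanged(true) and reads the
// flag back with a selector such as (s) => s.nodes.shouldShowEdgeLabels.
```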
@@ -1,24 +1,18 @@
 import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
-import { createSelector } from '@reduxjs/toolkit';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAIColorPicker from 'common/components/IAIColorPicker';
 import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
-import { memo, useCallback, useMemo } from 'react';
+import { memo, useCallback } from 'react';
 import type { RgbaColor } from 'react-colorful';
 import { useTranslation } from 'react-i18next';
 
+const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
+
 const ParamInfillColorOptions = () => {
   const dispatch = useAppDispatch();
 
-  const selector = useMemo(
-    () =>
-      createSelector(selectGenerationSlice, (generation) => ({
-        infillColor: generation.infillColorValue,
-      })),
-    []
-  );
-
-  const { infillColor } = useAppSelector(selector);
+  const infillColor = useAppSelector(selectInfillColor);
 
   const infillMethod = useAppSelector((s) => s.generation.infillMethod);
 
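This hunk (and several that follow) replaces a per-component `useMemo(() => createSelector(...), [])` with either a module-scope memoized selector or direct field selectors. A sketch of why, assuming reselect-style memoization semantics for `createMemoizedSelector`:

```ts
import { createSelector } from '@reduxjs/toolkit'; // re-exports reselect's createSelector

type RootState = { generation: { infillColorValue: string } };

// Module scope: built once at load time; every mounted component shares the
// same selector instance and its memoization cache.
export const selectInfillColor = createSelector(
  (s: RootState) => s.generation,
  (generation) => generation.infillColorValue
);

// The per-component useMemo variant builds a fresh selector (and a fresh
// cache) on every mount, which adds noise without improving memoization
// for a simple field read.
```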
@@ -1,35 +1,23 @@
 import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
-import { createSelector } from '@reduxjs/toolkit';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAIColorPicker from 'common/components/IAIColorPicker';
 import {
-  selectGenerationSlice,
   setInfillMosaicMaxColor,
   setInfillMosaicMinColor,
   setInfillMosaicTileHeight,
   setInfillMosaicTileWidth,
 } from 'features/parameters/store/generationSlice';
-import { memo, useCallback, useMemo } from 'react';
+import { memo, useCallback } from 'react';
 import type { RgbaColor } from 'react-colorful';
 import { useTranslation } from 'react-i18next';
 
 const ParamInfillMosaicTileSize = () => {
   const dispatch = useAppDispatch();
 
-  const selector = useMemo(
-    () =>
-      createSelector(selectGenerationSlice, (generation) => ({
-        infillMosaicTileWidth: generation.infillMosaicTileWidth,
-        infillMosaicTileHeight: generation.infillMosaicTileHeight,
-        infillMosaicMinColor: generation.infillMosaicMinColor,
-        infillMosaicMaxColor: generation.infillMosaicMaxColor,
-      })),
-    []
-  );
-
-  const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
-    useAppSelector(selector);
+  const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
+  const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
+  const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
+  const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
 
   const infillMethod = useAppSelector((s) => s.generation.infillMethod);
 
   const { t } = useTranslation();
@@ -1,4 +1,4 @@
-import { Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
+import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
@@ -46,20 +46,22 @@ const ParamMainModelSelect = () => {
   });
 
   return (
-    <Tooltip label={tooltipLabel}>
-      <FormControl isDisabled={!modelConfigs.length} isInvalid={!value || !modelConfigs.length}>
-        <InformationalPopover feature="paramModel">
-          <FormLabel>{t('modelManager.model')}</FormLabel>
-        </InformationalPopover>
-        <Combobox
-          value={value}
-          placeholder={placeholder}
-          options={options}
-          onChange={onChange}
-          noOptionsMessage={noOptionsMessage}
-        />
-      </FormControl>
-    </Tooltip>
+    <FormControl isDisabled={!modelConfigs.length} isInvalid={!value || !modelConfigs.length}>
+      <InformationalPopover feature="paramModel">
+        <FormLabel>{t('modelManager.model')}</FormLabel>
+      </InformationalPopover>
+      <Tooltip label={tooltipLabel}>
+        <Box w="full">
+          <Combobox
+            value={value}
+            placeholder={placeholder}
+            options={options}
+            onChange={onChange}
+            noOptionsMessage={noOptionsMessage}
+          />
+        </Box>
+      </Tooltip>
+    </FormControl>
   );
 };
 
@@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
   const clearQueueDisclosure = useDisclosure();
-  const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
-  const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+  const isPauseEnabled = useFeatureStatus('pauseQueue');
+  const isResumeEnabled = useFeatureStatus('resumeQueue');
   const { queueSize } = useGetQueueStatusQuery(undefined, {
     selectFromResult: (res) => ({
       queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
@@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
 import { QueueActionsMenuButton } from './QueueActionsMenuButton';
 
 const QueueControls = () => {
-  const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+  const isPrependEnabled = useFeatureStatus('prependQueue');
   return (
     <Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
       <ButtonGroup size="lg" isAttached={false}>
@@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
 import QueueTabQueueControls from './QueueTabQueueControls';
 
 const QueueTabContent = () => {
-  const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+  const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
 
   return (
     <Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
 import ResumeProcessorButton from './ResumeProcessorButton';
 
 const QueueTabQueueControls = () => {
-  const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
-  const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+  const isPauseEnabled = useFeatureStatus('pauseQueue');
+  const isResumeEnabled = useFeatureStatus('resumeQueue');
   return (
     <Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
       {isPauseEnabled || isResumeEnabled ? (
@@ -13,7 +13,7 @@ export const useQueueFront = () => {
   const [_, { isLoading }] = useEnqueueBatchMutation({
     fixedCacheKey: 'enqueueBatch',
   });
-  const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+  const prependEnabled = useFeatureStatus('prependQueue');
 
   const isDisabled = useMemo(() => {
     return !isReady || !prependEnabled;
@@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
 export const ControlSettingsAccordion: React.FC = memo(() => {
   const { t } = useTranslation();
   const { controlAdapterIds, badges } = useAppSelector(selector);
-  const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
+  const isControlNetEnabled = useFeatureStatus('controlNet');
   const { isOpen, onToggle } = useStandaloneAccordionToggle({
     id: 'control-settings',
     defaultIsOpen: true,
@@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
   const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
   const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
 
-  if (isControlNetDisabled) {
+  if (!isControlNetEnabled) {
     return null;
   }
 
@ -29,6 +29,7 @@ const selector = createMemoizedSelector(
|
|||||||
const { shouldRandomizeSeed, model } = generation;
|
const { shouldRandomizeSeed, model } = generation;
|
||||||
const { hrfEnabled } = hrf;
|
const { hrfEnabled } = hrf;
|
||||||
const badges: string[] = [];
|
const badges: string[] = [];
|
||||||
|
const isSDXL = model?.base === 'sdxl';
|
||||||
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
const {
|
const {
|
||||||
@ -53,10 +54,10 @@ const selector = createMemoizedSelector(
|
|||||||
badges.push('Manual Seed');
|
badges.push('Manual Seed');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hrfEnabled) {
|
if (hrfEnabled && !isSDXL) {
|
||||||
badges.push('HiRes Fix');
|
badges.push('HiRes Fix');
|
||||||
}
|
}
|
||||||
return { badges, activeTabName, isSDXL: model?.base === 'sdxl' };
|
return { badges, activeTabName, isSDXL };
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
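A brief sketch of the badge guard added above, reading the intent rather than quoting the diff: high-res fix does not apply to SDXL pipelines, so the badge is suppressed when an SDXL model is selected even if HRF is toggled on.

```ts
const badges: string[] = [];
const hrfEnabled = true;
const isSDXL = true; // i.e. model?.base === 'sdxl'

if (hrfEnabled && !isSDXL) {
  badges.push('HiRes Fix');
}
// badges === [] here: the HRF badge stays hidden for SDXL models.
```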
@@ -21,7 +21,6 @@ import {
 import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
 import { discordLink, githubLink, websiteLink } from 'features/system/store/constants';
 import { map } from 'lodash-es';
-/** @knipignore */
 import InvokeLogoYellow from 'public/assets/images/invoke-tag-lrg.svg';
 import type { ReactElement } from 'react';
 import { cloneElement, memo, useCallback } from 'react';
@@ -2,7 +2,6 @@
 import { Image, Text, Tooltip } from '@invoke-ai/ui-library';
 import { useStore } from '@nanostores/react';
 import { $logo } from 'app/store/nanostores/logo';
-/** @knipignore */
 import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
 import { memo, useMemo, useRef } from 'react';
 import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo';
@@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
   const language = useAppSelector((s) => s.system.language);
-  const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
+  const isLocalizationEnabled = useFeatureStatus('localization');
 
   const value = useMemo(() => options.find((o) => o.value === language), [language]);
 
@@ -23,9 +23,9 @@ const SettingsMenu = () => {
   const { isOpen, onOpen, onClose } = useDisclosure();
   useGlobalMenuClose(onClose);
 
-  const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
-  const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
-  const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
+  const isBugLinkEnabled = useFeatureStatus('bugLink');
+  const isDiscordLinkEnabled = useFeatureStatus('discordLink');
+  const isGithubLinkEnabled = useFeatureStatus('githubLink');
 
   return (
     <Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
@@ -1,32 +1,24 @@
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import type { AppFeature, SDFeature } from 'app/types/invokeai';
+import { selectConfigSlice } from 'features/system/store/configSlice';
 import type { InvokeTabName } from 'features/ui/store/tabMap';
 import { useMemo } from 'react';
 
 export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
-  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
-
-  const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
-
-  const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);
-
-  const isFeatureDisabled = useMemo(
+  const selectIsFeatureEnabled = useMemo(
     () =>
-      disabledFeatures.includes(feature as AppFeature) ||
-      disabledSDFeatures.includes(feature as SDFeature) ||
-      disabledTabs.includes(feature as InvokeTabName),
-    [disabledFeatures, disabledSDFeatures, disabledTabs, feature]
+      createSelector(selectConfigSlice, (config) => {
+        return !(
+          config.disabledFeatures.includes(feature as AppFeature) ||
+          config.disabledSDFeatures.includes(feature as SDFeature) ||
+          config.disabledTabs.includes(feature as InvokeTabName)
+        );
+      }),
+    [feature]
   );
 
-  const isFeatureEnabled = useMemo(
-    () =>
-      !(
-        disabledFeatures.includes(feature as AppFeature) ||
-        disabledSDFeatures.includes(feature as SDFeature) ||
-        disabledTabs.includes(feature as InvokeTabName)
-      ),
-    [disabledFeatures, disabledSDFeatures, disabledTabs, feature]
-  );
-
-  return { isFeatureDisabled, isFeatureEnabled };
+  const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);
+
+  return isFeatureEnabled;
 };
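The useFeatureStatus rewrite above explains the mechanical change repeated through the UI hunks: the hook now returns the enabled flag directly instead of an object, deriving it inside a memoized selector so a component re-renders only when the boolean flips. A sketch of the call-site migration:

```ts
// Before: object return; call sites picked a polarity off it.
//   const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
//   const isDisabled = useFeatureStatus('controlNet').isFeatureDisabled;

// After: boolean return; the "disabled" polarity becomes a negation.
declare function useFeatureStatus(feature: string): boolean;

export function Example() {
  const isPauseEnabled = useFeatureStatus('pauseQueue');
  const isControlNetEnabled = useFeatureStatus('controlNet');
  if (!isControlNetEnabled) {
    return null;
  }
  return isPauseEnabled;
}
```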
@@ -33,7 +33,7 @@ classifiers = [
 ]
 dependencies = [
   # Core generation dependencies, pinned for reproducible builds.
-  "accelerate==0.28.0",
+  "accelerate==0.29.2",
   "clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel==2.0.2",
   "controlnet-aux==0.0.7",
@@ -47,16 +47,16 @@ dependencies = [
   "pytorch-lightning==2.1.3",
   "safetensors==0.4.2",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-  "torch==2.2.1",
+  "torch==2.2.2",
   "torchmetrics==0.11.4",
   "torchsde==0.2.6",
-  "torchvision==0.17.1",
-  "transformers==4.39.1",
+  "torchvision==0.17.2",
+  "transformers==4.39.3",
 
   # Core application dependencies, pinned for reproducible builds.
   "fastapi-events==0.11.0",
   "fastapi==0.110.0",
-  "huggingface-hub==0.21.4",
+  "huggingface-hub==0.22.2",
   "pydantic-settings==2.2.1",
   "pydantic==2.6.3",
   "python-socketio==5.11.1",
@@ -96,7 +96,7 @@ dependencies = [
 [project.optional-dependencies]
 "xformers" = [
   # Core generation dependencies, pinned for reproducible builds.
-  "xformers==0.0.25; sys_platform!='darwin'",
+  "xformers==0.0.25post1; sys_platform!='darwin'",
   # Auxiliary dependencies, pinned only if necessary.
   "triton; sys_platform=='linux'",
 ]
132
tests/backend/util/test_devices.py
Normal file
@@ -0,0 +1,132 @@
+"""
+Test abstract device class.
+"""
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from invokeai.app.services.config import get_config
+from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
+
+devices = ["cpu", "cuda:0", "cuda:1", "mps"]
+device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
+device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
+device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]
+
+
+@pytest.mark.parametrize("device_name", devices)
+def test_device_choice(device_name):
+    config = get_config()
+    config.device = device_name
+    torch_device = TorchDevice.choose_torch_device()
+    assert torch_device == torch.device(device_name)
+
+
+@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
+def test_device_dtype_cpu(device_dtype_pair):
+    with (
+        patch("torch.cuda.is_available", return_value=False),
+        patch("torch.backends.mps.is_available", return_value=False),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        torch_dtype = TorchDevice.choose_torch_dtype()
+        assert torch_dtype == dtype
+
+
+@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
+def test_device_dtype_cuda(device_dtype_pair):
+    with (
+        patch("torch.cuda.is_available", return_value=True),
+        patch("torch.cuda.get_device_name", return_value="RTX4070"),
+        patch("torch.backends.mps.is_available", return_value=False),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        torch_dtype = TorchDevice.choose_torch_dtype()
+        assert torch_dtype == dtype
+
+
+@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
+def test_device_dtype_mps(device_dtype_pair):
+    with (
+        patch("torch.cuda.is_available", return_value=False),
+        patch("torch.backends.mps.is_available", return_value=True),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        torch_dtype = TorchDevice.choose_torch_dtype()
+        assert torch_dtype == dtype
+
+
+@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
+def test_device_dtype_override(device_dtype_pair):
+    with (
+        patch("torch.cuda.get_device_name", return_value="RTX4070"),
+        patch("torch.cuda.is_available", return_value=True),
+        patch("torch.backends.mps.is_available", return_value=False),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        config.precision = "float32"
+        torch_dtype = TorchDevice.choose_torch_dtype()
+        assert torch_dtype == torch.float32
+
+
+def test_normalize():
+    assert (
+        TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
+    )
+    assert (
+        TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
+    )
+    assert (
+        TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
+    )
+    assert TorchDevice.normalize("mps") == torch.device("mps")
+    assert TorchDevice.normalize("cpu") == torch.device("cpu")
+
+
+@pytest.mark.parametrize("device_name", devices)
+def test_legacy_device_choice(device_name):
+    config = get_config()
+    config.device = device_name
+    with pytest.deprecated_call():
+        torch_device = choose_torch_device()
+    assert torch_device == torch.device(device_name)
+
+
+@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
+def test_legacy_device_dtype_cpu(device_dtype_pair):
+    with (
+        patch("torch.cuda.is_available", return_value=False),
+        patch("torch.backends.mps.is_available", return_value=False),
+        patch("torch.cuda.get_device_name", return_value="RTX9090"),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        with pytest.deprecated_call():
+            torch_device = choose_torch_device()
+            returned_dtype = torch_dtype(torch_device)
+        assert returned_dtype == dtype
+
+
+def test_legacy_precision_name():
+    config = get_config()
+    config.precision = "auto"
+    with (
+        pytest.deprecated_call(),
+        patch("torch.cuda.is_available", return_value=True),
+        patch("torch.backends.mps.is_available", return_value=True),
+        patch("torch.cuda.get_device_name", return_value="RTX9090"),
+    ):
+        assert "float16" == choose_precision(torch.device("cuda"))
+        assert "float16" == choose_precision(torch.device("mps"))
+        assert "float32" == choose_precision(torch.device("cpu"))