diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index 58ec78c0b8..9636911b2e 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -1,11 +1,11 @@ name: Close inactive issues on: schedule: - - cron: "00 6 * * *" + - cron: "00 4 * * *" env: - DAYS_BEFORE_ISSUE_STALE: 14 - DAYS_BEFORE_ISSUE_CLOSE: 28 + DAYS_BEFORE_ISSUE_STALE: 30 + DAYS_BEFORE_ISSUE_CLOSE: 14 jobs: close-issues: @@ -14,7 +14,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v8 with: days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }} days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }} @@ -23,5 +23,6 @@ jobs: close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue." days-before-pr-stale: -1 days-before-pr-close: -1 + exempt-issue-labels: "Active Issue" repo-token: ${{ secrets.GITHUB_TOKEN }} operations-per-run: 500 diff --git a/docs/assets/nodes/groupsallscale.png b/docs/assets/nodes/groupsallscale.png new file mode 100644 index 0000000000..aa67db9a3e Binary files /dev/null and b/docs/assets/nodes/groupsallscale.png differ diff --git a/docs/assets/nodes/groupsconditioning.png b/docs/assets/nodes/groupsconditioning.png new file mode 100644 index 0000000000..b988cee78c Binary files /dev/null and b/docs/assets/nodes/groupsconditioning.png differ diff --git a/docs/assets/nodes/groupscontrol.png b/docs/assets/nodes/groupscontrol.png new file mode 100644 index 0000000000..ad696c3087 Binary files /dev/null and b/docs/assets/nodes/groupscontrol.png differ diff --git a/docs/assets/nodes/groupsimgvae.png b/docs/assets/nodes/groupsimgvae.png new file mode 100644 index 0000000000..c60bf40d67 Binary files /dev/null and b/docs/assets/nodes/groupsimgvae.png differ diff --git a/docs/assets/nodes/groupsiterate.png b/docs/assets/nodes/groupsiterate.png 
new file mode 100644 index 0000000000..9c1cd15bc2 Binary files /dev/null and b/docs/assets/nodes/groupsiterate.png differ diff --git a/docs/assets/nodes/groupslora.png b/docs/assets/nodes/groupslora.png new file mode 100644 index 0000000000..befcee6490 Binary files /dev/null and b/docs/assets/nodes/groupslora.png differ diff --git a/docs/assets/nodes/groupsmultigenseeding.png b/docs/assets/nodes/groupsmultigenseeding.png new file mode 100644 index 0000000000..a644146c86 Binary files /dev/null and b/docs/assets/nodes/groupsmultigenseeding.png differ diff --git a/docs/assets/nodes/groupsnoise.png b/docs/assets/nodes/groupsnoise.png new file mode 100644 index 0000000000..afb1e23e81 Binary files /dev/null and b/docs/assets/nodes/groupsnoise.png differ diff --git a/docs/assets/nodes/groupsrandseed.png b/docs/assets/nodes/groupsrandseed.png new file mode 100644 index 0000000000..06430cdee4 Binary files /dev/null and b/docs/assets/nodes/groupsrandseed.png differ diff --git a/docs/assets/nodes/nodescontrol.png b/docs/assets/nodes/nodescontrol.png new file mode 100644 index 0000000000..8b179e43ac Binary files /dev/null and b/docs/assets/nodes/nodescontrol.png differ diff --git a/docs/assets/nodes/nodesi2i.png b/docs/assets/nodes/nodesi2i.png new file mode 100644 index 0000000000..9908833804 Binary files /dev/null and b/docs/assets/nodes/nodesi2i.png differ diff --git a/docs/assets/nodes/nodest2i.png b/docs/assets/nodes/nodest2i.png new file mode 100644 index 0000000000..7e882dbf1b Binary files /dev/null and b/docs/assets/nodes/nodest2i.png differ diff --git a/docs/features/NODES.md b/docs/features/NODES.md index 5af40678d9..eef71eb974 100644 --- a/docs/features/NODES.md +++ b/docs/features/NODES.md @@ -118,49 +118,49 @@ There are several node grouping concepts that can be examined with a narrow focu As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input. 
-groupsnoise +![groupsnoise](../assets/nodes/groupsnoise.png) ### Conditioning As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node. -groupsconditioning +![groupsconditioning](../assets/nodes/groupsconditioning.png) ### Image Space & VAE The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space. -groupsimgvae +![groupsimgvae](../assets/nodes/groupsimgvae.png) ### Defined & Random Seeds It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed. -groupsrandseed +![groupsrandseed](../assets/nodes/groupsrandseed.png) ### Control Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet. -groupscontrol +![groupscontrol](../assets/nodes/groupscontrol.png) ### LoRA The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction. 
-groupslora +![groupslora](../assets/nodes/groupslora.png) ### Scaling Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results. -groupsallscale +![groupsallscale](../assets/nodes/groupsallscale.png) ### Iteration + Multiple Images as Input Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time. -groupsiterate +![groupsiterate](../assets/nodes/groupsiterate.png) ### Multiple Image Generation + Random Seeds @@ -168,7 +168,7 @@ Multiple image generation in the node editor is done using the RandomRange node. To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results. -groupsmultigenseeding +![groupsmultigenseeding](../assets/nodes/groupsmultigenseeding.png) ## Examples @@ -176,7 +176,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow ### Basic text-to-image Node Graph -nodest2i +![nodest2i](../assets/nodes/nodest2i.png) - Model Loader: A necessity to generating images (as we’ve read above). 
We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE. - Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, ‘dog’), one will output negative (what you don’t want, ‘cat’). They both input the CLIP tokenizer that the Model Loader node outputs. @@ -186,7 +186,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow ### Basic image-to-image Node Graph -nodesi2i +![nodesi2i](../assets/nodes/nodesi2i.png) - Model Loader: Choose a model from the dropdown. - Prompt (Compel): Two prompt nodes. One positive (dog), one negative (dog). Same CLIP inputs from the Model Loader node as before. @@ -197,7 +197,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow ### Basic ControlNet Node Graph -nodescontrol +![nodescontrol](../assets/nodes/nodescontrol.png) - Model Loader - Prompt (Compel) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 870ca33534..759f6c9f59 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -298,7 +298,7 @@ async def search_for_models( )->List[pathlib.Path]: if not search_path.is_dir(): raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") - return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) @models_router.get( "/ckpt_confs", diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index fa1b6939d2..6aadbf509d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation): def _lora_loader(): for lora in self.clip.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + 
**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return @@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation): class SDXLPromptInvocationBase: def run_clip_raw(self, context, clip_field, prompt, get_pooled): tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.dict(), + **clip_field.tokenizer.dict(), context=context, ) text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.dict(), + **clip_field.text_encoder.dict(), context=context, ) def _lora_loader(): for lora in clip_field.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return @@ -196,6 +196,7 @@ class SDXLPromptInvocationBase: model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, + context=context, ).context.model ) except ModelNotFoundException: @@ -240,16 +241,16 @@ class SDXLPromptInvocationBase: def run_clip_compel(self, context, clip_field, prompt, get_pooled): tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.dict(), + **clip_field.tokenizer.dict(), context=context, ) text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.dict(), + **clip_field.text_encoder.dict(), context=context, ) def _lora_loader(): for lora in clip_field.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return @@ -265,6 +266,7 @@ class SDXLPromptInvocationBase: model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, + context=context, ).context.model ) except ModelNotFoundException: diff --git 
a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index be9515d5d2..b0742d0419 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -2,15 +2,18 @@ from typing import Literal, Optional, Union from pydantic import BaseModel, Field -from invokeai.app.invocations.baseinvocation import (BaseInvocation, - BaseInvocationOutput, InvocationConfig, - InvocationContext) +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationConfig, + InvocationContext, +) from invokeai.app.invocations.controlnet_image_processors import ControlField -from invokeai.app.invocations.model import (LoRAModelField, MainModelField, - VAEModelField) +from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField class LoRAMetadataField(BaseModel): """LoRA metadata for an image generated in InvokeAI.""" + lora: LoRAModelField = Field(description="The LoRA model") weight: float = Field(description="The weight of the LoRA model") @@ -18,7 +21,9 @@ class LoRAMetadataField(BaseModel): class CoreMetadata(BaseModel): """Core generation metadata for an image generated in InvokeAI.""" - generation_mode: str = Field(description="The generation mode that output this image",) + generation_mode: str = Field( + description="The generation mode that output this image", + ) positive_prompt: str = Field(description="The positive prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter") width: int = Field(description="The width parameter") @@ -28,10 +33,20 @@ class CoreMetadata(BaseModel): cfg_scale: float = Field(description="The classifier-free guidance scale parameter") steps: int = Field(description="The number of steps used for inference") scheduler: str = Field(description="The scheduler used for inference") - clip_skip: int = Field(description="The number of skipped CLIP layers",) + clip_skip: int = Field( + description="The 
number of skipped CLIP layers", + ) model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField]= Field(description="The ControlNets used for inference") + controlnets: list[ControlField] = Field( + description="The ControlNets used for inference" + ) loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") + vae: Union[VAEModelField, None] = Field( + default=None, + description="The VAE used for decoding, if the main model's default was not used", + ) + + # Latents-to-Latents strength: Union[float, None] = Field( default=None, description="The strength used for latents-to-latents", @@ -39,9 +54,34 @@ class CoreMetadata(BaseModel): init_image: Union[str, None] = Field( default=None, description="The name of the initial image" ) - vae: Union[VAEModelField, None] = Field( + + # SDXL + positive_style_prompt: Union[str, None] = Field( + default=None, description="The positive style prompt parameter" + ) + negative_style_prompt: Union[str, None] = Field( + default=None, description="The negative style prompt parameter" + ) + + # SDXL Refiner + refiner_model: Union[MainModelField, None] = Field( + default=None, description="The SDXL Refiner model used" + ) + refiner_cfg_scale: Union[float, None] = Field( default=None, - description="The VAE used for decoding, if the main model's default was not used", + description="The classifier-free guidance scale parameter used for the refiner", + ) + refiner_steps: Union[int, None] = Field( + default=None, description="The number of steps used for the refiner" + ) + refiner_scheduler: Union[str, None] = Field( + default=None, description="The scheduler used for the refiner" + ) + refiner_aesthetic_store: Union[float, None] = Field( + default=None, description="The aesthetic score used for the refiner" + ) + refiner_start: Union[float, None] = Field( + default=None, description="The start value used for refiner denoising" ) @@ -70,7 +110,9 @@ class 
MetadataAccumulatorInvocation(BaseInvocation): type: Literal["metadata_accumulator"] = "metadata_accumulator" - generation_mode: str = Field(description="The generation mode that output this image",) + generation_mode: str = Field( + description="The generation mode that output this image", + ) positive_prompt: str = Field(description="The positive prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter") width: int = Field(description="The width parameter") @@ -80,9 +122,13 @@ class MetadataAccumulatorInvocation(BaseInvocation): cfg_scale: float = Field(description="The classifier-free guidance scale parameter") steps: int = Field(description="The number of steps used for inference") scheduler: str = Field(description="The scheduler used for inference") - clip_skip: int = Field(description="The number of skipped CLIP layers",) + clip_skip: int = Field( + description="The number of skipped CLIP layers", + ) model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField]= Field(description="The ControlNets used for inference") + controlnets: list[ControlField] = Field( + description="The ControlNets used for inference" + ) loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") strength: Union[float, None] = Field( default=None, @@ -96,36 +142,44 @@ class MetadataAccumulatorInvocation(BaseInvocation): description="The VAE used for decoding, if the main model's default was not used", ) + # SDXL + positive_style_prompt: Union[str, None] = Field( + default=None, description="The positive style prompt parameter" + ) + negative_style_prompt: Union[str, None] = Field( + default=None, description="The negative style prompt parameter" + ) + + # SDXL Refiner + refiner_model: Union[MainModelField, None] = Field( + default=None, description="The SDXL Refiner model used" + ) + refiner_cfg_scale: Union[float, None] = Field( + default=None, + description="The 
classifier-free guidance scale parameter used for the refiner", + ) + refiner_steps: Union[int, None] = Field( + default=None, description="The number of steps used for the refiner" + ) + refiner_scheduler: Union[str, None] = Field( + default=None, description="The scheduler used for the refiner" + ) + refiner_aesthetic_store: Union[float, None] = Field( + default=None, description="The aesthetic score used for the refiner" + ) + refiner_start: Union[float, None] = Field( + default=None, description="The start value used for refiner denoising" + ) + class Config(InvocationConfig): schema_extra = { "ui": { "title": "Metadata Accumulator", - "tags": ["image", "metadata", "generation"] + "tags": ["image", "metadata", "generation"], }, } - def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: """Collects and outputs a CoreMetadata object""" - return MetadataAccumulatorOutput( - metadata=CoreMetadata( - generation_mode=self.generation_mode, - positive_prompt=self.positive_prompt, - negative_prompt=self.negative_prompt, - width=self.width, - height=self.height, - seed=self.seed, - rand_device=self.rand_device, - cfg_scale=self.cfg_scale, - steps=self.steps, - scheduler=self.scheduler, - model=self.model, - strength=self.strength, - init_image=self.init_image, - vae=self.vae, - controlnets=self.controlnets, - loras=self.loras, - clip_skip=self.clip_skip, - ) - ) + return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict())) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 442557520a..fff0f29f14 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation): @validator("seed", pre=True) def modulo_seed(cls, v): - """Returns the seed modulo SEED_MAX to ensure it is within the valid range.""" - return v % SEED_MAX + """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" + return v % (SEED_MAX + 
1) def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index f877b22924..3b4d3a9d86 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): "ui": { "title": "SDXL Refiner Model Loader", "tags": ["model", "loader", "sdxl_refiner"], - "type_hints": {"model": "model"}, + "type_hints": {"model": "refiner_model"}, }, } @@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + **self.unet.unet.dict(), context=context ) do_classifier_free_guidance = True cross_attention_kwargs = None @@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): unet: UNetField = Field(default=None, description="UNet submodel") latents: Optional[LatentsField] = Field(description="Initial latents") - denoising_start: float = Field(default=0.0, ge=0, lt=1, description="") - denoising_end: float = Field(default=1.0, gt=0, le=1, description="") + denoising_start: float = Field(default=0.0, ge=0, le=1, description="") + denoising_end: float = Field(default=1.0, ge=0, le=1, description="") #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) @@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): num_inference_steps = num_inference_steps - t_start # apply noise(if provided) - if self.noise is not None: + if self.noise is not None and timesteps.shape[0] > 0: noise = context.services.latents.get(self.noise.latents_name) latents = scheduler.add_noise(latents, noise, timesteps[:1]) del noise unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + 
**self.unet.unet.dict(), context=context, ) do_classifier_free_guidance = True cross_attention_kwargs = None diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 35003536e6..73d74de2d9 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -3,7 +3,13 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp -from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo +from invokeai.app.services.model_manager_service import ( + BaseModelType, + ModelType, + SubModelType, + ModelInfo, +) + class EventServiceBase: session_event: str = "session_event" @@ -38,7 +44,9 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, - progress_image=progress_image.dict() if progress_image is not None else None, + progress_image=progress_image.dict() + if progress_image is not None + else None, step=step, total_steps=total_steps, ), @@ -67,6 +75,7 @@ class EventServiceBase: graph_execution_state_id: str, node: dict, source_node_id: str, + error_type: str, error: str, ) -> None: """Emitted when an invocation has completed""" @@ -76,6 +85,7 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, + error_type=error_type, error=error, ), ) @@ -102,13 +112,13 @@ class EventServiceBase: ), ) - def emit_model_load_started ( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + def emit_model_load_started( + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, ) -> None: """Emitted when a model is requested""" self.__emit_session_event( @@ -123,13 +133,13 @@ class EventServiceBase: ) def 
emit_model_load_completed( - self, - graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + self, + graph_execution_state_id: str, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: SubModelType, + model_info: ModelInfo, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_session_event( @@ -145,3 +155,37 @@ class EventServiceBase: precision=str(model_info.precision), ), ) + + def emit_session_retrieval_error( + self, + graph_execution_state_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when session retrieval fails""" + self.__emit_session_event( + event_name="session_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + error_type=error_type, + error=error, + ), + ) + + def emit_invocation_retrieval_error( + self, + graph_execution_state_id: str, + node_id: str, + error_type: str, + error: str, + ) -> None: + """Emitted when invocation retrieval fails""" + self.__emit_session_event( + event_name="invocation_retrieval_error", + payload=dict( + graph_execution_state_id=graph_execution_state_id, + node_id=node_id, + error_type=error_type, + error=error, + ), + ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index b1b995309e..f7d3b3a7a7 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase): """ Return list of all models found in the designated directory. 
""" - search = FindModels(directory,self.logger) + search = FindModels([directory], self.logger) return search.list_models() def sync_to_config(self): diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e11eb84b3d..5995e4ffc3 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() except Exception as e: - logger.debug("Exception while getting from queue: %s" % e) + self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e) if not queue_item: # Probably stopping # do not hammer the queue time.sleep(0.5) continue - graph_execution_state = ( - self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id + try: + graph_execution_state = ( + self.__invoker.services.graph_execution_manager.get( + queue_item.graph_execution_state_id + ) ) - ) - invocation = graph_execution_state.execution_graph.get_node( - queue_item.invocation_id - ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) + self.__invoker.services.events.emit_session_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue + + try: + invocation = graph_execution_state.execution_graph.get_node( + queue_item.invocation_id + ) + except Exception as e: + self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) + self.__invoker.services.events.emit_invocation_retrieval_error( + graph_execution_state_id=queue_item.graph_execution_state_id, + node_id=queue_item.invocation_id, + error_type=e.__class__.__name__, + error=traceback.format_exc(), + ) + continue # get the source node id to provide to clients (the prepared node id is not as useful) 
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id] @@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state ) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=error, ) @@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): try: self.__invoker.invoke(graph_execution_state, invoke_all=True) except Exception as e: - logger.error("Error while invoking: %s" % e) + self.__invoker.services.logger.error("Error while invoking:\n%s" % e) self.__invoker.services.events.emit_invocation_error( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), source_node_id=source_node_id, + error_type=e.__class__.__name__, error=traceback.format_exc() ) elif is_complete: diff --git a/invokeai/app/util/misc.py b/invokeai/app/util/misc.py index 7c674674e2..503f3af4c8 100644 --- a/invokeai/app/util/misc.py +++ b/invokeai/app/util/misc.py @@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime: return datetime.datetime.fromisoformat(iso_timestamp) -SEED_MAX = np.iinfo(np.int32).max +SEED_MAX = np.iinfo(np.uint32).max def get_random_seed(): - return np.random.randint(0, SEED_MAX) + rng = np.random.default_rng(seed=None) + return int(rng.integers(0, SEED_MAX)) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index b0481f3cfa..222169afbb 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -474,7 +474,7 @@ class ModelPatcher: @staticmethod def _lora_forward_hook( - applied_loras: List[Tuple[LoraModel, float]], + applied_loras: List[Tuple[LoRAModel, float]], layer_name: str, ): @@ -519,7 +519,7 @@ class 
ModelPatcher: def apply_lora( cls, model: torch.nn.Module, - loras: List[Tuple[LoraModel, float]], + loras: List[Tuple[LoRAModel, float]], prefix: str, ): original_weights = dict() diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py index 1e282b4bb8..5657bd9549 100644 --- a/invokeai/backend/model_management/model_search.py +++ b/invokeai/backend/model_management/model_search.py @@ -98,6 +98,6 @@ class FindModels(ModelSearch): def list_models(self) -> List[Path]: self.search() - return self.models_found + return list(self.models_found) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 5387ade0e5..eb771841ec 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -10,6 +10,7 @@ from .base import ( SubModelType, classproperty, InvalidModelException, + ModelNotFoundException, ) # TODO: naming from ..lora import LoRAModel as LoRAModelRaw diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 1194ea467b..b8fab16c1c 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,2 +1,2 @@ export const NUMPY_RAND_MIN = 0; -export const NUMPY_RAND_MAX = 2147483647; +export const NUMPY_RAND_MAX = 4294967295; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 04f0ce7a0b..f06c324bc6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -65,11 +65,14 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from 
'./listeners/socketio/socketGraphExecutionStateComplete'; import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete'; import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; +import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad'; +import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; +import { addTabChangedListener } from './listeners/tabChanged'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; @@ -153,6 +156,8 @@ addSocketDisconnectedListener(); addSocketSubscribedListener(); addSocketUnsubscribedListener(); addModelLoadEventListener(); +addSessionRetrievalErrorEventListener(); +addInvocationRetrievalErrorEventListener(); // Session Created addSessionCreatedPendingListener(); @@ -197,3 +202,6 @@ addFirstListImagesListener(); // Ad-hoc upscale workflwo addUpscaleRequestedListener(); + +// Tab Change +addTabChangedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts index 68148a192f..ef81377f99 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts @@ -1,4 +1,8 @@ import { setInfillMethod } from 'features/parameters/store/generationSlice'; +import { + shouldUseNSFWCheckerChanged, + shouldUseWatermarkerChanged, +} from 'features/system/store/systemSlice'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import { startAppListening } from '..'; @@ -6,12 +10,21 @@ export const addAppConfigReceivedListener = () => { startAppListening({ matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled, effect: async (action, { getState, dispatch }) => { - const { infill_methods } = action.payload; + const { infill_methods, nsfw_methods, watermarking_methods } = + action.payload; const infillMethod = getState().generation.infillMethod; if (!infill_methods.includes(infillMethod)) { dispatch(setInfillMethod(infill_methods[0])); } + + if (!nsfw_methods.includes('nsfw_checker')) { + dispatch(shouldUseNSFWCheckerChanged(false)); + } + + if (!watermarking_methods.includes('invisible_watermark')) { + dispatch(shouldUseWatermarkerChanged(false)); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 57981918d8..1e0b3dbc61 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -9,13 +9,19 @@ import { zMainModel, zVaeModel, } from 'features/parameters/types/parameterSchemas'; +import { + refinerModelChanged, + setShouldUseSDXLRefiner, +} from 'features/sdxl/store/sdxlSlice'; 
import { forEach, some } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; import { startAppListening } from '..'; export const addModelsLoadedListener = () => { startAppListening({ - matcher: modelsApi.endpoints.getMainModels.matchFulfilled, + predicate: (state, action) => + modelsApi.endpoints.getMainModels.matchFulfilled(action) && + !action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { // models loaded, we need to ensure the selected model is available and if not, select the first one const log = logger('models'); @@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => { dispatch(modelChanged(result.data)); }, }); + startAppListening({ + predicate: (state, action) => + modelsApi.endpoints.getMainModels.matchFulfilled(action) && + action.meta.arg.originalArgs.includes('sdxl-refiner'), + effect: async (action, { getState, dispatch }) => { + // models loaded, we need to ensure the selected model is available and if not, select the first one + const log = logger('models'); + log.info( + { models: action.payload.entities }, + `SDXL Refiner models loaded (${action.payload.ids.length})` + ); + + const currentModel = getState().sdxl.refinerModel; + + const isCurrentModelAvailable = some( + action.payload.entities, + (m) => + m?.model_name === currentModel?.model_name && + m?.base_model === currentModel?.base_model + ); + + if (isCurrentModelAvailable) { + return; + } + + const firstModelId = action.payload.ids[0]; + const firstModel = action.payload.entities[firstModelId]; + + if (!firstModel) { + // No models loaded at all + dispatch(refinerModelChanged(null)); + dispatch(setShouldUseSDXLRefiner(false)); + return; + } + + const result = zMainModel.safeParse(firstModel); + + if (!result.success) { + log.error( + { error: result.error.format() }, + 'Failed to parse SDXL Refiner Model' + ); + return; + } + + dispatch(refinerModelChanged(result.data)); + }, + }); startAppListening({ 
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, effect: async (action, { getState, dispatch }) => { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts index 5709d87d22..e89acb7542 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => { effect: (action) => { const log = logger('session'); if (action.payload) { - const { error } = action.payload; + const { error, status } = action.payload; const graph = parseify(action.meta.arg); - const stringifiedError = JSON.stringify(error); log.error( - { graph, error: serializeError(error) }, - `Problem creating session: ${stringifiedError}` + { graph, status, error: serializeError(error) }, + `Problem creating session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts index 60009ed194..a62f75d957 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => { const { session_id } = action.meta.arg; if (action.payload) { const { error } = action.payload; - const stringifiedError = JSON.stringify(error); log.error( { session_id, error: serializeError(error), }, - `Problem invoking session: ${stringifiedError}` + `Problem invoking session` ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index ae0f049ae7..32a6cce203 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,4 +1,6 @@ import { logger } from 'app/logging/logger'; +import { LIST_TAG } from 'services/api'; +import { appInfoApi } from 'services/api/endpoints/appInfo'; import { modelsApi } from 'services/api/endpoints/models'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { appSocketConnected, socketConnected } from 'services/events/actions'; @@ -24,11 +26,18 @@ export const addSocketConnectedEventListener = () => { dispatch(appSocketConnected(action.payload)); // update all server state - dispatch(modelsApi.endpoints.getMainModels.initiate()); - dispatch(modelsApi.endpoints.getControlNetModels.initiate()); - dispatch(modelsApi.endpoints.getLoRAModels.initiate()); - dispatch(modelsApi.endpoints.getTextualInversionModels.initiate()); - dispatch(modelsApi.endpoints.getVaeModels.initiate()); + dispatch( + modelsApi.util.invalidateTags([ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + { type: 'LoRAModel', id: LIST_TAG }, + { type: 'ControlNetModel', id: LIST_TAG }, + { type: 'VaeModel', id: LIST_TAG }, + { type: 'TextualInversionModel', id: LIST_TAG }, + { type: 'ScannedModels', id: LIST_TAG }, + ]) + ); + dispatch(appInfoApi.util.invalidateTags(['AppConfig', 'AppVersion'])); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts new file mode 100644 index 0000000000..aa88457eb7 --- /dev/null +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts @@ -0,0 +1,20 @@ +import { logger } from 'app/logging/logger'; +import { + appSocketInvocationRetrievalError, + socketInvocationRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addInvocationRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketInvocationRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Invocation retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketInvocationRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts index 15447e5350..4e77811762 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => { return; } - log.debug(action.payload, 'Invocation started'); + log.debug( + action.payload, + `Invocation started (${action.payload.data.node.type})` + ); dispatch(appSocketInvocationStarted(action.payload)); }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts new file mode 100644 index 0000000000..7efb7f463a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts @@ -0,0 +1,20 
@@ +import { logger } from 'app/logging/logger'; +import { + appSocketSessionRetrievalError, + socketSessionRetrievalError, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +export const addSessionRetrievalErrorEventListener = () => { + startAppListening({ + actionCreator: socketSessionRetrievalError, + effect: (action, { dispatch }) => { + const log = logger('socketio'); + log.error( + action.payload, + `Session retrieval error (${action.payload.data.graph_execution_state_id})` + ); + dispatch(appSocketSessionRetrievalError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts new file mode 100644 index 0000000000..578241573c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/tabChanged.ts @@ -0,0 +1,56 @@ +import { modelChanged } from 'features/parameters/store/generationSlice'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; +import { + MainModelConfigEntity, + modelsApi, +} from 'services/api/endpoints/models'; +import { startAppListening } from '..'; + +export const addTabChangedListener = () => { + startAppListening({ + actionCreator: setActiveTab, + effect: (action, { getState, dispatch }) => { + const activeTabName = action.payload; + if (activeTabName === 'unifiedCanvas') { + // grab the models from RTK Query cache + const { data } = modelsApi.endpoints.getMainModels.select( + NON_REFINER_BASE_MODELS + )(getState()); + + if (!data) { + // no models yet, so we can't do anything + dispatch(modelChanged(null)); + return; + } + + // need to filter out all the invalid canvas models (currently, this is just sdxl) + const validCanvasModels: MainModelConfigEntity[] = []; + + forEach(data.entities, 
(entity) => { + if (!entity) { + return; + } + if (['sd-1', 'sd-2'].includes(entity.base_model)) { + validCanvasModels.push(entity); + } + }); + + // this could still be undefined even tho TS doesn't say so + const firstValidCanvasModel = validCanvasModels[0]; + + if (!firstValidCanvasModel) { + // uh oh, we have no models that are valid for canvas + dispatch(modelChanged(null)); + return; + } + + // only store the model name and base model in redux + const { base_model, model_name } = firstValidCanvasModel; + + dispatch(modelChanged({ base_model, model_name })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 2ef62aed7b..39bd742d7d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => { const state = getState(); + const { + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + } = state.canvas; + // Build canvas blobs - const canvasBlobsAndImageData = await getCanvasData(state); + const canvasBlobsAndImageData = await getCanvasData( + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea + ); if (!canvasBlobsAndImageData) { log.error('Unable to create canvas data'); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts index 8101530eea..b0172e693b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts @@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions'; import { parseify } from 'common/util/serialize'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph'; +import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph'; import { sessionReadyToInvoke } from 'features/system/store/actions'; import { sessionCreated } from 'services/api/thunks/session'; import { startAppListening } from '..'; @@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => { effect: async (action, { getState, dispatch, take }) => { const log = logger('session'); const state = getState(); + const model = state.generation.model; + + let graph; + + if (model && model.base_model === 'sdxl') { + graph = buildLinearSDXLImageToImageGraph(state); + } else { + graph = buildLinearImageToImageGraph(state); + } - const graph = buildLinearImageToImageGraph(state); dispatch(imageToImageGraphBuilt(graph)); log.debug({ graph: parseify(graph) }, 'Image to Image graph built'); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts index a9e9fe1ad8..f1cdabcabd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts @@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { parseify } from 'common/util/serialize'; import { textToImageGraphBuilt } from 'features/nodes/store/actions'; +import { 
buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph'; import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph'; import { sessionReadyToInvoke } from 'features/system/store/actions'; import { sessionCreated } from 'services/api/thunks/session'; @@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => { effect: async (action, { getState, dispatch, take }) => { const log = logger('session'); const state = getState(); + const model = state.generation.model; - const graph = buildLinearTextToImageGraph(state); + let graph; + + if (model && model.base_model === 'sdxl') { + graph = buildLinearSDXLTextToImageGraph(state); + } else { + graph = buildLinearTextToImageGraph(state); + } dispatch(textToImageGraphBuilt(graph)); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index dd80e1e378..d71a147913 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; +import sdxlReducer from 'features/sdxl/store/sdxlSlice'; import configReducer from 'features/system/store/configSlice'; import systemReducer from 'features/system/store/systemSlice'; import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice'; @@ -47,6 +48,7 @@ const allReducers = { imageDeletion: imageDeletionReducer, lora: loraReducer, modelmanager: modelmanagerReducer, + sdxl: sdxlReducer, [api.reducerPath]: api.reducer, }; @@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'canvas', 'gallery', 'generation', + 'sdxl', 'nodes', 'postprocessing', 
'system', diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index be642a6435..b38790e0c9 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -95,7 +95,8 @@ export type AppFeature = | 'localization' | 'consoleLogging' | 'dynamicPrompting' - | 'batches'; + | 'batches' + | 'syncModels'; /** * A disable-able Stable Diffusion feature diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index 384150be10..f43ec1851f 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; // import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { modelsApi } from '../../services/api/endpoints/models'; const readinessSelector = createSelector( @@ -24,7 +25,7 @@ const readinessSelector = createSelector( } const { isSuccess: mainModelsSuccessfullyLoaded } = - modelsApi.endpoints.getMainModels.select()(state); + modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state); if (!mainModelsSuccessfullyLoaded) { isReady = false; reasonsWhyNotReady.push('Models are not loaded'); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx index 69bf628a39..8c1dfbb86f 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStatusText.tsx @@ -2,8 +2,8 @@ import { Box, Flex } from '@chakra-ui/react'; import 
{ createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; +import GenerationModeStatusText from 'features/parameters/components/Parameters/Canvas/GenerationModeStatusText'; import { isEqual } from 'lodash-es'; - import { useTranslation } from 'react-i18next'; import roundToHundreth from '../util/roundToHundreth'; import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos'; @@ -110,6 +110,7 @@ const IAICanvasStatusText = () => { }, }} > + { }} > - { + const layerState = useAppSelector((state) => state.canvas.layerState); + + const boundingBoxCoordinates = useAppSelector( + (state) => state.canvas.boundingBoxCoordinates + ); + const boundingBoxDimensions = useAppSelector( + (state) => state.canvas.boundingBoxDimensions + ); + const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled); + + const shouldPreserveMaskedArea = useAppSelector( + (state) => state.canvas.shouldPreserveMaskedArea + ); + const [generationMode, setGenerationMode] = useState< + GenerationMode | undefined + >(); + + useEffect(() => { + setGenerationMode(undefined); + }, [ + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + ]); + + useDebounce( + async () => { + // Build canvas blobs + const canvasBlobsAndImageData = await getCanvasData( + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea + ); + + if (!canvasBlobsAndImageData) { + return; + } + + const { baseImageData, maskImageData } = canvasBlobsAndImageData; + + // Determine the generation mode + const generationMode = getCanvasGenerationMode( + baseImageData, + maskImageData + ); + + setGenerationMode(generationMode); + }, + 1000, + [ + layerState, + boundingBoxCoordinates, + boundingBoxDimensions, + isMaskEnabled, + shouldPreserveMaskedArea, + ] + ); + + return 
generationMode; +}; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index 48d59395ab..ba85a7e132 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -168,4 +168,7 @@ export interface CanvasState { stageDimensions: Dimensions; stageScale: number; tool: CanvasTool; + generationMode?: GenerationMode; } + +export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts index d37ee7b8d0..4e575791ed 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts @@ -1,6 +1,10 @@ import { logger } from 'app/logging/logger'; -import { RootState } from 'app/store/store'; -import { isCanvasMaskLine } from '../store/canvasTypes'; +import { Vector2d } from 'konva/lib/types'; +import { + CanvasLayerState, + Dimensions, + isCanvasMaskLine, +} from '../store/canvasTypes'; import createMaskStage from './createMaskStage'; import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider'; import { konvaNodeToBlob } from './konvaNodeToBlob'; @@ -9,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData'; /** * Gets Blob and ImageData objects for the base and mask layers */ -export const getCanvasData = async (state: RootState) => { +export const getCanvasData = async ( + layerState: CanvasLayerState, + boundingBoxCoordinates: Vector2d, + boundingBoxDimensions: Dimensions, + isMaskEnabled: boolean, + shouldPreserveMaskedArea: boolean +) => { const log = logger('canvas'); const canvasBaseLayer = getCanvasBaseLayer(); @@ -20,14 +30,6 @@ export const getCanvasData = async (state: RootState) => { return; } - const { - layerState: { objects 
}, - boundingBoxCoordinates, - boundingBoxDimensions, - isMaskEnabled, - shouldPreserveMaskedArea, - } = state.canvas; - const boundingBox = { ...boundingBoxCoordinates, ...boundingBoxDimensions, @@ -58,7 +60,7 @@ export const getCanvasData = async (state: RootState) => { // For the mask layer, use the normal boundingBox const maskStage = await createMaskStage( - isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled + isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled boundingBox, shouldPreserveMaskedArea ); diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts index 5b38ecf938..d3e8792690 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasGenerationMode.ts @@ -2,11 +2,12 @@ import { areAnyPixelsBlack, getImageDataTransparency, } from 'common/util/arrayBuffer'; +import { GenerationMode } from '../store/canvasTypes'; export const getCanvasGenerationMode = ( baseImageData: ImageData, maskImageData: ImageData -) => { +): GenerationMode => { const { isPartiallyTransparent: baseIsPartiallyTransparent, isFullyTransparent: baseIsFullyTransparent, diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 23effc5375..0ecc43ef9c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent'; import UnetInputFieldComponent from './fields/UnetInputFieldComponent'; import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; 
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; +import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent'; type InputFieldComponentProps = { nodeId: string; @@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'refiner_model' && template.type === 'refiner_model') { + return ( + + ); + } + if (type === 'vae_model' && template.type === 'vae_model') { return ( @@ -24,8 +26,11 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); + const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; - const { data: mainModels, isLoading } = useGetMainModelsQuery(); + const { data: mainModels, isLoading } = useGetMainModelsQuery( + NON_REFINER_BASE_MODELS + ); const data = useMemo(() => { if (!mainModels) { @@ -103,9 +108,11 @@ const ModelInputFieldComponent = ( disabled={data.length === 0} onChange={handleChangeModel} /> - - - + {isSyncModelEnabled && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx new file mode 100644 index 0000000000..28c6567e8d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/RefinerModelInputFieldComponent.tsx @@ -0,0 +1,120 @@ +import { Box, Flex } from '@chakra-ui/react'; +import { SelectItem } from '@mantine/core'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + RefinerModelInputFieldTemplate, + RefinerModelInputFieldValue, +} from 'features/nodes/types/types'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToMainModelParam } from 
'features/parameters/util/modelIdToMainModelParam'; +import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { REFINER_BASE_MODELS } from 'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { FieldComponentProps } from './types'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; + +const RefinerModelInputFieldComponent = ( + props: FieldComponentProps< + RefinerModelInputFieldValue, + RefinerModelInputFieldTemplate + > +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; + const { data: refinerModels, isLoading } = + useGetMainModelsQuery(REFINER_BASE_MODELS); + + const data = useMemo(() => { + if (!refinerModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(refinerModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [refinerModels]); + + // grab the full model entity from the RTK Query cache + // TODO: maybe we should just store the full model entity in state? + const selectedModel = useMemo( + () => + refinerModels?.entities[ + `${field.value?.base_model}/main/${field.value?.model_name}` + ] ?? 
null, + [field.value?.base_model, field.value?.model_name, refinerModels?.entities] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + const newModel = modelIdToMainModelParam(v); + + if (!newModel) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: newModel, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + return isLoading ? ( + + ) : ( + + 0 ? 'Select a model' : 'No models available'} + data={data} + error={data.length === 0} + disabled={data.length === 0} + onChange={handleChangeModel} + /> + {isSyncModelEnabled && ( + + + + )} + + ); +}; + +export default memo(RefinerModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c07efc9c27..d83d240847 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record = { ClipField: 'clip', VaeField: 'vae', model: 'model', + refiner_model: 'refiner_model', vae_model: 'vae_model', lora_model: 'lora_model', controlnet_model: 'controlnet_model', @@ -120,6 +121,12 @@ export const FIELDS: Record = { title: 'Model', description: 'Models are models.', }, + refiner_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'Refiner Model', + description: 'Models are models.', + }, vae_model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index f111155a39..157b990b96 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -70,6 +70,7 @@ export type FieldType = | 'vae' | 'control' | 'model' + | 'refiner_model' | 'vae_model' | 'lora_model' | 
'controlnet_model' @@ -100,6 +101,7 @@ export type InputFieldValue = | ControlInputFieldValue | EnumInputFieldValue | MainModelInputFieldValue + | RefinerModelInputFieldValue | VaeModelInputFieldValue | LoRAModelInputFieldValue | ControlNetModelInputFieldValue @@ -128,6 +130,7 @@ export type InputFieldTemplate = | ControlInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate + | RefinerModelInputFieldTemplate | VaeModelInputFieldTemplate | LoRAModelInputFieldTemplate | ControlNetModelInputFieldTemplate @@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & { value?: MainModelParam; }; +export type RefinerModelInputFieldValue = FieldValueBase & { + type: 'refiner_model'; + value?: MainModelParam; +}; + export type VaeModelInputFieldValue = FieldValueBase & { type: 'vae_model'; value?: VaeModelParam; @@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & { type: 'model'; }; +export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'refiner_model'; +}; + export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { default: string; type: 'vae_model'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 9c01deded6..83692533f7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -22,6 +22,7 @@ import { LoRAModelInputFieldTemplate, ModelInputFieldTemplate, OutputFieldTemplate, + RefinerModelInputFieldTemplate, StringInputFieldTemplate, TypeHints, UNetInputFieldTemplate, @@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({ return template; }; +const buildRefinerModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): RefinerModelInputFieldTemplate => { + const template: RefinerModelInputFieldTemplate = { + 
...baseField, + type: 'refiner_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildVaeModelInputFieldTemplate = ({ schemaObject, baseField, @@ -492,6 +508,9 @@ export const buildInputFieldTemplate = ( if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } + if (['refiner_model'].includes(fieldType)) { + return buildRefinerModelInputFieldTemplate({ schemaObject, baseField }); + } if (['vae_model'].includes(fieldType)) { return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index f54a7640bd..3c6850a88a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -76,6 +76,10 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'refiner_model') { + fieldValue.value = undefined; + } + if (template.type === 'vae_model') { fieldValue.value = undefined; } diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts new file mode 100644 index 0000000000..35e7f3ac38 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts @@ -0,0 +1,70 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { + ImageNSFWBlurInvocation, + LatentsToImageInvocation, + MetadataAccumulatorInvocation, +} from 'services/api/types'; +import { + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NSFW_CHECKER, +} from './constants'; + +export const 
addNSFWCheckerToGraph = ( + state: RootState, + graph: NonNullableGraph, + nodeIdToAddTo = LATENTS_TO_IMAGE +): void => { + const activeTabName = activeTabNameSelector(state); + + const is_intermediate = + activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; + + const nodeToAddTo = graph.nodes[nodeIdToAddTo] as + | LatentsToImageInvocation + | undefined; + + const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as + | MetadataAccumulatorInvocation + | undefined; + + if (!nodeToAddTo) { + // something has gone terribly awry + return; + } + + nodeToAddTo.is_intermediate = true; + + const nsfwCheckerNode: ImageNSFWBlurInvocation = { + id: NSFW_CHECKER, + type: 'img_nsfw', + is_intermediate, + }; + + graph.nodes[NSFW_CHECKER] = nsfwCheckerNode; + graph.edges.push({ + source: { + node_id: nodeIdToAddTo, + field: 'image', + }, + destination: { + node_id: NSFW_CHECKER, + field: 'image', + }, + }); + + if (metadataAccumulator) { + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: NSFW_CHECKER, + field: 'metadata', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts new file mode 100644 index 0000000000..c47c7be8b4 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -0,0 +1,186 @@ +import { RootState } from 'app/store/store'; +import { MetadataAccumulatorInvocation } from 'services/api/types'; +import { NonNullableGraph } from '../../types/types'; +import { + IMAGE_TO_LATENTS, + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + SDXL_LATENTS_TO_LATENTS, + SDXL_MODEL_LOADER, + SDXL_REFINER_LATENTS_TO_LATENTS, + SDXL_REFINER_MODEL_LOADER, + SDXL_REFINER_NEGATIVE_CONDITIONING, + SDXL_REFINER_POSITIVE_CONDITIONING, +} from './constants'; + +export const addSDXLRefinerToGraph = ( + 
state: RootState, + graph: NonNullableGraph, + baseNodeId: string +): void => { + const { positivePrompt, negativePrompt } = state.generation; + const { + refinerModel, + refinerAestheticScore, + positiveStylePrompt, + negativeStylePrompt, + refinerSteps, + refinerScheduler, + refinerCFGScale, + refinerStart, + } = state.sdxl; + + if (!refinerModel) return; + + const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as + | MetadataAccumulatorInvocation + | undefined; + + if (metadataAccumulator) { + metadataAccumulator.refiner_model = refinerModel; + metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore; + metadataAccumulator.refiner_cfg_scale = refinerCFGScale; + metadataAccumulator.refiner_scheduler = refinerScheduler; + metadataAccumulator.refiner_start = refinerStart; + metadataAccumulator.refiner_steps = refinerSteps; + } + + // Unplug SDXL Latents Generation To Latents To Image + graph.edges = graph.edges.filter( + (e) => + !(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field)) + ); + + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === SDXL_MODEL_LOADER && + ['vae'].includes(e.source.field) + ) + ); + + // connect the VAE back to the i2l, which we just removed in the filter + // but only if we are doing l2l + if (baseNodeId === SDXL_LATENTS_TO_LATENTS) { + graph.edges.push({ + source: { + node_id: SDXL_MODEL_LOADER, + field: 'vae', + }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'vae', + }, + }); + } + + graph.nodes[SDXL_REFINER_MODEL_LOADER] = { + type: 'sdxl_refiner_model_loader', + id: SDXL_REFINER_MODEL_LOADER, + model: refinerModel, + }; + graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = { + type: 'sdxl_refiner_compel_prompt', + id: SDXL_REFINER_POSITIVE_CONDITIONING, + style: `${positivePrompt} ${positiveStylePrompt}`, + aesthetic_score: refinerAestheticScore, + }; + graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = { + type: 'sdxl_refiner_compel_prompt', + id: 
SDXL_REFINER_NEGATIVE_CONDITIONING, + style: `${negativePrompt} ${negativeStylePrompt}`, + aesthetic_score: refinerAestheticScore, + }; + graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = { + type: 'l2l_sdxl', + id: SDXL_REFINER_LATENTS_TO_LATENTS, + cfg_scale: refinerCFGScale, + steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)), + scheduler: refinerScheduler, + denoising_start: refinerStart, + denoising_end: 1, + }; + + graph.edges.push( + { + source: { + node_id: SDXL_REFINER_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: SDXL_REFINER_LATENTS_TO_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: SDXL_REFINER_MODEL_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }, + { + source: { + node_id: SDXL_REFINER_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: SDXL_REFINER_POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_REFINER_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: SDXL_REFINER_NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_REFINER_POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_REFINER_LATENTS_TO_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: SDXL_REFINER_NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_REFINER_LATENTS_TO_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: baseNodeId, + field: 'latents', + }, + destination: { + node_id: SDXL_REFINER_LATENTS_TO_LATENTS, + field: 'latents', + }, + }, + { + source: { + node_id: SDXL_REFINER_LATENTS_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + } + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts 
b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts new file mode 100644 index 0000000000..f2e8a0aeca --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts @@ -0,0 +1,95 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { + ImageNSFWBlurInvocation, + ImageWatermarkInvocation, + LatentsToImageInvocation, + MetadataAccumulatorInvocation, +} from 'services/api/types'; +import { + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NSFW_CHECKER, + WATERMARKER, +} from './constants'; + +export const addWatermarkerToGraph = ( + state: RootState, + graph: NonNullableGraph, + nodeIdToAddTo = LATENTS_TO_IMAGE +): void => { + const activeTabName = activeTabNameSelector(state); + + const is_intermediate = + activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false; + + const nodeToAddTo = graph.nodes[nodeIdToAddTo] as + | LatentsToImageInvocation + | undefined; + + const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as + | ImageNSFWBlurInvocation + | undefined; + + const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as + | MetadataAccumulatorInvocation + | undefined; + + if (!nodeToAddTo) { + // something has gone terribly awry + return; + } + + const watermarkerNode: ImageWatermarkInvocation = { + id: WATERMARKER, + type: 'img_watermark', + is_intermediate, + }; + + graph.nodes[WATERMARKER] = watermarkerNode; + + // no matter the situation, we want the l2i node to be intermediate + nodeToAddTo.is_intermediate = true; + + if (nsfwCheckerNode) { + // if we are using NSFW checker, we need to "disable" it output by marking it intermediate, + // then connect it to the watermark node + nsfwCheckerNode.is_intermediate = true; + graph.edges.push({ + source: { + node_id: NSFW_CHECKER, + field: 'image', + }, + destination: { + node_id: 
WATERMARKER, + field: 'image', + }, + }); + } else { + // otherwise we just connect to the watermark node + graph.edges.push({ + source: { + node_id: nodeIdToAddTo, + field: 'image', + }, + destination: { + node_id: WATERMARKER, + field: 'image', + }, + }); + } + + if (metadataAccumulator) { + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: WATERMARKER, + field: 'metadata', + }, + }); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 7920d2638a..81e99e0dd0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -10,7 +10,9 @@ import { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, IMAGE_TO_IMAGE_GRAPH, @@ -23,8 +25,6 @@ import { NOISE, POSITIVE_CONDITIONING, RESIZE, - NSFW_CHECKER, - WATERMARKER, } from './constants'; /** @@ -105,11 +105,6 @@ export const buildCanvasImageToImageGraph = ( is_intermediate: true, skipped_layers: clipSkip, }, - [LATENTS_TO_IMAGE]: { - type: 'l2i', - id: LATENTS_TO_IMAGE, - is_intermediate: true, - }, [LATENTS_TO_LATENTS]: { type: 'l2l', id: LATENTS_TO_LATENTS, @@ -128,15 +123,10 @@ export const buildCanvasImageToImageGraph = ( // image_name: initialImage.image_name, // }, }, - [NSFW_CHECKER]: { - type: 'img_nsfw', - id: NSFW_CHECKER, - is_intermediate: true, - }, - [WATERMARKER]: { + [LATENTS_TO_IMAGE]: { + type: 
'l2i', + id: LATENTS_TO_IMAGE, is_intermediate: !shouldAutoSave, - type: 'img_watermark', - id: WATERMARKER, }, }, edges: [ @@ -180,26 +170,6 @@ export const buildCanvasImageToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination: { - node_id: NSFW_CHECKER, - field: 'image', - }, - }, - { - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination: { - node_id: WATERMARKER, - field: 'image', - }, - }, { source: { node_id: IMAGE_TO_LATENTS, @@ -342,17 +312,6 @@ export const buildCanvasImageToImageGraph = ( init_image: initialImage.image_name, }; - graph.edges.push({ - source: { - node_id: METADATA_ACCUMULATOR, - field: 'metadata', - }, - destination: { - node_id: WATERMARKER, - field: 'metadata', - }, - }); - // add LoRA support addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS); @@ -365,5 +324,16 @@ export const buildCanvasImageToImageGraph = ( // add controlnet, mutating `graph` addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS); + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! 
+ addWatermarkerToGraph(state, graph); + } + return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index fa6e20aaa4..4154c1b5eb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -20,6 +20,8 @@ import { RANDOM_INT, RANGE_OF_SIZE, } from './constants'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; /** * Builds the Canvas tab's Inpaint graph. @@ -249,5 +251,16 @@ export const buildCanvasInpaintGraph = ( (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; } + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph, INPAINT); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! 
+ addWatermarkerToGraph(state, graph, INPAINT); + } + return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index 88e2f5e70b..597a643367 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, LATENTS_TO_IMAGE, @@ -16,8 +18,6 @@ import { POSITIVE_CONDITIONING, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, - NSFW_CHECKER, - WATERMARKER, } from './constants'; /** @@ -109,16 +109,6 @@ export const buildCanvasTextToImageGraph = ( [LATENTS_TO_IMAGE]: { type: 'l2i', id: LATENTS_TO_IMAGE, - is_intermediate: true, - }, - [NSFW_CHECKER]: { - type: 'img_nsfw', - id: NSFW_CHECKER, - is_intermediate: true, - }, - [WATERMARKER]: { - type: 'img_watermark', - id: WATERMARKER, is_intermediate: !shouldAutoSave, }, }, @@ -193,26 +183,6 @@ export const buildCanvasTextToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination: { - node_id: NSFW_CHECKER, - field: 'image', - }, - }, - { - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination: { - node_id: WATERMARKER, - field: 'image', - }, - }, { source: { node_id: NOISE, @@ -247,17 +217,6 @@ export const buildCanvasTextToImageGraph = ( clip_skip: clipSkip, }; - 
graph.edges.push({ - source: { - node_id: METADATA_ACCUMULATOR, - field: 'metadata', - }, - destination: { - node_id: WATERMARKER, - field: 'metadata', - }, - }); - // add LoRA support addLoRAsToGraph(state, graph, TEXT_TO_LATENTS); @@ -270,5 +229,16 @@ export const buildCanvasTextToImageGraph = ( // add controlnet, mutating `graph` addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS); + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph); + } + return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index ebe0a51f99..d78e2e1356 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -9,7 +9,9 @@ import { import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, IMAGE_TO_IMAGE_GRAPH, @@ -22,8 +24,6 @@ import { NOISE, POSITIVE_CONDITIONING, RESIZE, - NSFW_CHECKER, - WATERMARKER, } from './constants'; /** @@ -48,6 +48,7 @@ export const buildLinearImageToImageGraph = ( clipSkip, shouldUseCpuNoise, shouldUseNoiseSettings, + vaePrecision, } = state.generation; // TODO: add batch functionality @@ -115,7 +116,7 @@ export const buildLinearImageToImageGraph = ( [LATENTS_TO_IMAGE]: { type: 
'l2i', id: LATENTS_TO_IMAGE, - is_intermediate: true, + fp32: vaePrecision === 'fp32' ? true : false, }, [LATENTS_TO_LATENTS]: { type: 'l2l', @@ -131,15 +132,8 @@ export const buildLinearImageToImageGraph = ( // must be set manually later, bc `fit` parameter may require a resize node inserted // image: { // image_name: initialImage.image_name, - }, - [NSFW_CHECKER]: { - type: 'img_nsfw', - id: NSFW_CHECKER, - is_intermediate: true, - }, - [WATERMARKER]: { - type: 'img_watermark', - id: WATERMARKER, + // }, + fp32: vaePrecision === 'fp32' ? true : false, }, }, edges: [ @@ -193,26 +187,6 @@ export const buildLinearImageToImageGraph = ( field: 'latents', }, }, - { - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination: { - node_id: NSFW_CHECKER, - field: 'image', - }, - }, - { - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination: { - node_id: WATERMARKER, - field: 'image', - }, - }, { source: { node_id: IMAGE_TO_LATENTS, @@ -325,42 +299,6 @@ export const buildLinearImageToImageGraph = ( }); } - // TODO: add batch functionality - // if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) { - // // we are going to connect an iterate up to the init image - // delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image; - - // const imageCollection: ImageCollectionInvocation = { - // id: IMAGE_COLLECTION, - // type: 'image_collection', - // images: batchImageNames.map((image_name) => ({ image_name })), - // }; - - // const imageCollectionIterate: IterateInvocation = { - // id: IMAGE_COLLECTION_ITERATE, - // type: 'iterate', - // }; - - // graph.nodes[IMAGE_COLLECTION] = imageCollection; - // graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate; - - // graph.edges.push({ - // source: { node_id: IMAGE_COLLECTION, field: 'collection' }, - // destination: { - // node_id: IMAGE_COLLECTION_ITERATE, - // field: 'collection', - // }, - // }); - - // graph.edges.push({ - // source: { node_id: 
IMAGE_COLLECTION_ITERATE, field: 'item' }, - // destination: { - // node_id: IMAGE_TO_LATENTS, - // field: 'image', - // }, - // }); - // } - // add metadata accumulator, which is only mostly populated - some fields are added later graph.nodes[METADATA_ACCUMULATOR] = { id: METADATA_ACCUMULATOR, @@ -384,17 +322,6 @@ export const buildLinearImageToImageGraph = ( init_image: initialImage.imageName, }; - graph.edges.push({ - source: { - node_id: METADATA_ACCUMULATOR, - field: 'metadata', - }, - destination: { - node_id: WATERMARKER, - field: 'metadata', - }, - }); - // add LoRA support addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS); @@ -407,5 +334,16 @@ export const buildLinearImageToImageGraph = ( // add controlnet, mutating `graph` addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS); + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! 
+ addWatermarkerToGraph(state, graph); + } + return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts new file mode 100644 index 0000000000..8c82002b2e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts @@ -0,0 +1,382 @@ +import { logger } from 'app/logging/logger'; +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { initialGenerationState } from 'features/parameters/store/generationSlice'; +import { + ImageResizeInvocation, + ImageToLatentsInvocation, +} from 'services/api/types'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; +import { + IMAGE_TO_LATENTS, + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NEGATIVE_CONDITIONING, + NOISE, + POSITIVE_CONDITIONING, + RESIZE, + SDXL_IMAGE_TO_IMAGE_GRAPH, + SDXL_LATENTS_TO_LATENTS, + SDXL_MODEL_LOADER, +} from './constants'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; + +/** + * Builds the Image to Image tab graph. 
+ */ +export const buildLinearSDXLImageToImageGraph = ( + state: RootState +): NonNullableGraph => { + const log = logger('nodes'); + const { + positivePrompt, + negativePrompt, + model, + cfgScale: cfg_scale, + scheduler, + steps, + initialImage, + shouldFitToWidthHeight, + width, + height, + clipSkip, + shouldUseCpuNoise, + shouldUseNoiseSettings, + vaePrecision, + } = state.generation; + + const { + positiveStylePrompt, + negativeStylePrompt, + shouldUseSDXLRefiner, + refinerStart, + sdxlImg2ImgDenoisingStrength: strength, + } = state.sdxl; + + /** + * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the + * full graph here as a template. Then use the parameters from app state and set friendlier node + * ids. + * + * The only thing we need extra logic for is handling randomized seed, control net, and for img2img, + * the `fit` param. These are added to the graph at the end. + */ + + if (!initialImage) { + log.error('No initial image found in state'); + throw new Error('No initial image found in state'); + } + + if (!model) { + log.error('No model found in state'); + throw new Error('No model found in state'); + } + + const use_cpu = shouldUseNoiseSettings + ? 
shouldUseCpuNoise + : initialGenerationState.shouldUseCpuNoise; + + // copy-pasted graph from node editor, filled in with state values & friendly node ids + const graph: NonNullableGraph = { + id: SDXL_IMAGE_TO_IMAGE_GRAPH, + nodes: { + [SDXL_MODEL_LOADER]: { + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }, + [POSITIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: positiveStylePrompt, + }, + [NEGATIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: negativeStylePrompt, + }, + [NOISE]: { + type: 'noise', + id: NOISE, + use_cpu, + }, + [LATENTS_TO_IMAGE]: { + type: 'l2i', + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === 'fp32' ? true : false, + }, + [SDXL_LATENTS_TO_LATENTS]: { + type: 'l2l_sdxl', + id: SDXL_LATENTS_TO_LATENTS, + cfg_scale, + scheduler, + steps, + denoising_start: shouldUseSDXLRefiner + ? Math.min(refinerStart, 1 - strength) + : 1 - strength, + denoising_end: shouldUseSDXLRefiner ? refinerStart : 1, + }, + [IMAGE_TO_LATENTS]: { + type: 'i2l', + id: IMAGE_TO_LATENTS, + // must be set manually later, bc `fit` parameter may require a resize node inserted + // image: { + // image_name: initialImage.image_name, + // }, + fp32: vaePrecision === 'fp32' ? 
true : false, + }, + }, + edges: [ + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'vae', + }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'vae', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + }, + { + source: { + node_id: IMAGE_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'latents', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'noise', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_LATENTS_TO_LATENTS, + field: 'negative_conditioning', + }, + }, + ], + }; + + // handle `fit` + if ( + shouldFitToWidthHeight && + (initialImage.width !== width || initialImage.height !== height) + ) { + // The init 
image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` + + // Create a resize node, explicitly setting its image + const resizeNode: ImageResizeInvocation = { + id: RESIZE, + type: 'img_resize', + image: { + image_name: initialImage.imageName, + }, + is_intermediate: true, + width, + height, + }; + + graph.nodes[RESIZE] = resizeNode; + + // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS` + graph.edges.push({ + source: { node_id: RESIZE, field: 'image' }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'image', + }, + }); + + // The `RESIZE` node also passes its width and height to `NOISE` + graph.edges.push({ + source: { node_id: RESIZE, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + + graph.edges.push({ + source: { node_id: RESIZE, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } else { + // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly + (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = { + image_name: initialImage.imageName, + }; + + // Pass the image's dimensions to the `NOISE` node + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } + + // add metadata accumulator, which is only mostly populated - some fields are added later + graph.nodes[METADATA_ACCUMULATOR] = { + id: METADATA_ACCUMULATOR, + type: 'metadata_accumulator', + generation_mode: 'sdxl_img2img', + cfg_scale, + height, + width, + positive_prompt: '', // set in addDynamicPromptsToGraph + negative_prompt: negativePrompt, + model, + seed: 0, // set in addDynamicPromptsToGraph + steps, + rand_device: use_cpu ? 
'cpu' : 'cuda', + scheduler, + vae: undefined, + controlnets: [], + loras: [], + clip_skip: clipSkip, + strength: strength, + init_image: initialImage.imageName, + positive_style_prompt: positiveStylePrompt, + negative_style_prompt: negativeStylePrompt, + }; + + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'metadata', + }, + }); + + // Add Refiner if enabled + if (shouldUseSDXLRefiner) { + addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS); + } + + // add dynamic prompts - also sets up core iteration and seed + addDynamicPromptsToGraph(state, graph); + + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph); + } + + return graph; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts new file mode 100644 index 0000000000..36f35a90de --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts @@ -0,0 +1,264 @@ +import { logger } from 'app/logging/logger'; +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { initialGenerationState } from 'features/parameters/store/generationSlice'; +import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; +import { + LATENTS_TO_IMAGE, + METADATA_ACCUMULATOR, + NEGATIVE_CONDITIONING, + NOISE, + POSITIVE_CONDITIONING, + SDXL_MODEL_LOADER, + SDXL_TEXT_TO_IMAGE_GRAPH, + SDXL_TEXT_TO_LATENTS, +} from './constants'; +import { addNSFWCheckerToGraph 
} from './addNSFWCheckerToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; + +export const buildLinearSDXLTextToImageGraph = ( + state: RootState +): NonNullableGraph => { + const log = logger('nodes'); + const { + positivePrompt, + negativePrompt, + model, + cfgScale: cfg_scale, + scheduler, + steps, + width, + height, + clipSkip, + shouldUseCpuNoise, + shouldUseNoiseSettings, + vaePrecision, + } = state.generation; + + const { + positiveStylePrompt, + negativeStylePrompt, + shouldUseSDXLRefiner, + refinerStart, + } = state.sdxl; + + const use_cpu = shouldUseNoiseSettings + ? shouldUseCpuNoise + : initialGenerationState.shouldUseCpuNoise; + + if (!model) { + log.error('No model found in state'); + throw new Error('No model found in state'); + } + + /** + * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the + * full graph here as a template. Then use the parameters from app state and set friendlier node + * ids. + * + * The only thing we need extra logic for is handling randomized seed, control net, and for img2img, + * the `fit` param. These are added to the graph at the end. + */ + + // copy-pasted graph from node editor, filled in with state values & friendly node ids + const graph: NonNullableGraph = { + id: SDXL_TEXT_TO_IMAGE_GRAPH, + nodes: { + [SDXL_MODEL_LOADER]: { + type: 'sdxl_model_loader', + id: SDXL_MODEL_LOADER, + model, + }, + [POSITIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: POSITIVE_CONDITIONING, + prompt: positivePrompt, + style: positiveStylePrompt, + }, + [NEGATIVE_CONDITIONING]: { + type: 'sdxl_compel_prompt', + id: NEGATIVE_CONDITIONING, + prompt: negativePrompt, + style: negativeStylePrompt, + }, + [NOISE]: { + type: 'noise', + id: NOISE, + width, + height, + use_cpu, + }, + [SDXL_TEXT_TO_LATENTS]: { + type: 't2l_sdxl', + id: SDXL_TEXT_TO_LATENTS, + cfg_scale, + scheduler, + steps, + denoising_end: shouldUseSDXLRefiner ? 
refinerStart : 1, + }, + [LATENTS_TO_IMAGE]: { + type: 'l2i', + id: LATENTS_TO_IMAGE, + fp32: vaePrecision === 'fp32' ? true : false, + }, + }, + edges: [ + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: SDXL_TEXT_TO_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'vae', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'vae', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }, + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'clip2', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip2', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_TEXT_TO_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SDXL_TEXT_TO_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: SDXL_TEXT_TO_LATENTS, + field: 'noise', + }, + }, + { + source: { + node_id: SDXL_TEXT_TO_LATENTS, + field: 'latents', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'latents', + }, + }, + ], + }; + + // add metadata accumulator, which is only mostly populated - some fields are added later + graph.nodes[METADATA_ACCUMULATOR] = { + id: METADATA_ACCUMULATOR, + type: 'metadata_accumulator', + generation_mode: 'sdxl_txt2img', + cfg_scale, + height, + width, + positive_prompt: '', // set in addDynamicPromptsToGraph + 
negative_prompt: negativePrompt, + model, + seed: 0, // set in addDynamicPromptsToGraph + steps, + rand_device: use_cpu ? 'cpu' : 'cuda', + scheduler, + vae: undefined, + controlnets: [], + loras: [], + clip_skip: clipSkip, + positive_style_prompt: positiveStylePrompt, + negative_style_prompt: negativeStylePrompt, + }; + + graph.edges.push({ + source: { + node_id: METADATA_ACCUMULATOR, + field: 'metadata', + }, + destination: { + node_id: LATENTS_TO_IMAGE, + field: 'metadata', + }, + }); + + // Add Refiner if enabled + if (shouldUseSDXLRefiner) { + addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS); + } + + // add dynamic prompts - also sets up core iteration and seed + addDynamicPromptsToGraph(state, graph); + + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! + addWatermarkerToGraph(state, graph); + } + + return graph; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index f60f372a12..94f35feba0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph'; +import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addVAEToGraph } from './addVAEToGraph'; +import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { CLIP_SKIP, LATENTS_TO_IMAGE, @@ 
-16,8 +18,6 @@ import { POSITIVE_CONDITIONING, TEXT_TO_IMAGE_GRAPH, TEXT_TO_LATENTS, - NSFW_CHECKER, - WATERMARKER, } from './constants'; export const buildLinearTextToImageGraph = ( @@ -36,6 +36,7 @@ export const buildLinearTextToImageGraph = ( clipSkip, shouldUseCpuNoise, shouldUseNoiseSettings, + vaePrecision, } = state.generation; const use_cpu = shouldUseNoiseSettings @@ -97,16 +98,7 @@ export const buildLinearTextToImageGraph = ( [LATENTS_TO_IMAGE]: { type: 'l2i', id: LATENTS_TO_IMAGE, - is_intermediate: true, - }, - [NSFW_CHECKER]: { - type: 'img_nsfw', - id: NSFW_CHECKER, - is_intermediate: true, - }, - [WATERMARKER]: { - type: 'img_watermark', - id: WATERMARKER, + fp32: vaePrecision === 'fp32' ? true : false, }, }, edges: [ @@ -190,26 +182,6 @@ export const buildLinearTextToImageGraph = ( field: 'noise', }, }, - { - source: { - node_id: LATENTS_TO_IMAGE, - field: 'image', - }, - destination: { - node_id: NSFW_CHECKER, - field: 'image', - }, - }, - { - source: { - node_id: NSFW_CHECKER, - field: 'image', - }, - destination: { - node_id: WATERMARKER, - field: 'image', - }, - }, ], }; @@ -234,17 +206,6 @@ export const buildLinearTextToImageGraph = ( clip_skip: clipSkip, }; - graph.edges.push({ - source: { - node_id: METADATA_ACCUMULATOR, - field: 'metadata', - }, - destination: { - node_id: WATERMARKER, - field: 'metadata', - }, - }); - // add LoRA support addLoRAsToGraph(state, graph, TEXT_TO_LATENTS); @@ -257,5 +218,16 @@ export const buildLinearTextToImageGraph = ( // add controlnet, mutating `graph` addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS); + // NSFW & watermark - must be last thing added to graph + if (state.system.shouldUseNSFWChecker) { + // must add before watermarker! + addNSFWCheckerToGraph(state, graph); + } + + if (state.system.shouldUseWatermarker) { + // must add after nsfw checker! 
+ addWatermarkerToGraph(state, graph); + } + return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index d33bffb266..8cb9c6d50d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -25,8 +25,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator'; export const REALESRGAN = 'esrgan'; export const DIVIDE = 'divide'; export const SCALE = 'scale_image'; +export const SDXL_MODEL_LOADER = 'sdxl_model_loader'; +export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl'; +export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl'; +export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader'; +export const SDXL_REFINER_POSITIVE_CONDITIONING = + 'sdxl_refiner_positive_conditioning'; +export const SDXL_REFINER_NEGATIVE_CONDITIONING = + 'sdxl_refiner_negative_conditioning'; +export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner'; // friendly graph ids export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; +export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph'; +export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph'; export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph'; export const INPAINT_GRAPH = 'inpaint_graph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index bedf932b50..3a9cf233b5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -13,7 +13,12 @@ import { buildOutputFieldTemplates, } from './fieldTemplateBuilders'; -const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata']; +const getReservedFieldNames = (type: string): string[] => { + if (type === 'l2i') { + return ['id', 'type', 'metadata']; + 
} + return ['id', 'type', 'is_intermediate', 'metadata']; +}; const invocationDenylist = [ 'Graph', @@ -21,11 +26,11 @@ const invocationDenylist = [ 'MetadataAccumulatorInvocation', ]; -export const parseSchema = (openAPI: OpenAPIV3.Document) => { - // filter out non-invocation schemas, plus some tricky invocations for now +export const parseSchema = ( + openAPI: OpenAPIV3.Document +): Record => { const filteredSchemas = filter( - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - openAPI.components!.schemas, + openAPI.components?.schemas, (schema, key) => key.includes('Invocation') && !key.includes('InvocationOutput') && @@ -35,21 +40,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { const invocations = filteredSchemas.reduce< Record >((acc, schema) => { - // only want SchemaObjects if (isInvocationSchemaObject(schema)) { const type = schema.properties.type.default; + const RESERVED_FIELD_NAMES = getReservedFieldNames(type); const title = schema.ui?.title ?? 
schema.title.replace('Invocation', ''); - const typeHints = schema.ui?.type_hints; const inputs: Record = {}; if (type === 'collect') { - const itemProperty = schema.properties[ - 'item' - ] as InvocationSchemaObject; - // Handle the special Collect node + const itemProperty = schema.properties.item as InvocationSchemaObject; inputs.item = { type: 'item', name: 'item', @@ -60,10 +61,8 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { default: undefined, }; } else if (type === 'iterate') { - const itemProperty = schema.properties[ - 'collection' - ] as InvocationSchemaObject; - + const itemProperty = schema.properties + .collection as InvocationSchemaObject; inputs.collection = { type: 'array', name: 'collection', @@ -74,18 +73,18 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { inputKind: 'connection', }; } else { - // All other nodes reduce( schema.properties, (inputsAccumulator, property, propertyName) => { if ( - // `type` and `id` are not valid inputs/outputs !RESERVED_FIELD_NAMES.includes(propertyName) && isSchemaObject(property) ) { - const field: InputFieldTemplate | undefined = - buildInputFieldTemplate(property, propertyName, typeHints); - + const field = buildInputFieldTemplate( + property, + propertyName, + typeHints + ); if (field) { inputsAccumulator[propertyName] = field; } @@ -97,22 +96,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { } const rawOutput = (schema as InvocationSchemaObject).output; - let outputs: Record; - // some special handling is needed for collect, iterate and range nodes if (type === 'iterate') { - // this is guaranteed to be a SchemaObject - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const iterationOutput = openAPI.components!.schemas![ + const iterationOutput = openAPI.components?.schemas?.[ 'IterateInvocationOutput' ] as OpenAPIV3.SchemaObject; - outputs = { item: { name: 'item', - title: iterationOutput.title ?? 
'', - description: iterationOutput.description ?? '', + title: iterationOutput?.title ?? '', + description: iterationOutput?.description ?? '', type: 'array', }, }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx new file mode 100644 index 0000000000..511e90f0f3 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/GenerationModeStatusText.tsx @@ -0,0 +1,21 @@ +import { Box } from '@chakra-ui/react'; +import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode'; + +const GENERATION_MODE_NAME_MAP = { + txt2img: 'Text to Image', + img2img: 'Image to Image', + inpaint: 'Inpaint', + outpaint: 'Inpaint', +}; + +const GenerationModeStatusText = () => { + const generationMode = useCanvasGenerationMode(); + + return ( + + Mode: {generationMode ? GENERATION_MODE_NAME_MAP[generationMode] : '...'} + + ); +}; + +export default GenerationModeStatusText; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx index 74418de1d3..37c5eb30c4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler.tsx @@ -4,6 +4,7 @@ import { memo } from 'react'; import ParamMainModelSelect from '../MainModel/ParamMainModelSelect'; import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect'; import ParamScheduler from './ParamScheduler'; +import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision'; const ParamModelandVAEandScheduler = () => { const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled; @@ -13,16 
+14,15 @@ const ParamModelandVAEandScheduler = () => { - - {isVaeEnabled && ( - - - - )} - - - - + + + + {isVaeEnabled && ( + + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx index 75f1bc8bd9..d380da60bf 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/MainModel/ParamMainModelSelect.tsx @@ -13,8 +13,11 @@ import { modelSelected } from 'features/parameters/store/actions'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { forEach } from 'lodash-es'; +import { NON_REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; +import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; const selector = createSelector( stateSelector, @@ -28,7 +31,12 @@ const ParamMainModelSelect = () => { const { model } = useAppSelector(selector); - const { data: mainModels, isLoading } = useGetMainModelsQuery(); + const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; + const { data: mainModels, isLoading } = useGetMainModelsQuery( + NON_REFINER_BASE_MODELS + ); + + const activeTabName = useAppSelector(activeTabNameSelector); const data = useMemo(() => { if (!mainModels) { @@ -38,7 +46,10 @@ const ParamMainModelSelect = () => { const data: SelectItem[] = []; forEach(mainModels.entities, (model, id) => { - if (!model || ['sdxl', 
'sdxl-refiner'].includes(model.base_model)) { + if ( + !model || + (activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl') + ) { return; } @@ -50,7 +61,7 @@ const ParamMainModelSelect = () => { }); return data; - }, [mainModels]); + }, [mainModels, activeTabName]); // grab the full model entity from the RTK Query cache // TODO: maybe we should just store the full model entity in state? @@ -86,7 +97,7 @@ const ParamMainModelSelect = () => { data={[]} /> ) : ( - + { onChange={handleChangeModel} w="100%" /> - - - + {isSyncModelEnabled && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeed.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeed.tsx index d5ced67d0e..481fe27964 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeed.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeed.tsx @@ -32,11 +32,6 @@ export default function ParamSeed() { isInvalid={seed < 0 && shouldGenerateVariations} onChange={handleChangeSeed} value={seed} - formControlProps={{ - display: 'flex', - alignItems: 'center', - gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2? 
- }} /> ); } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedFull.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedFull.tsx index 75a5d189ae..a1887ec896 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedFull.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Seed/ParamSeedFull.tsx @@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize'; const ParamSeedFull = () => { return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx new file mode 100644 index 0000000000..c57cdc1132 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx @@ -0,0 +1,46 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { vaePrecisionChanged } from 'features/parameters/store/generationSlice'; +import { PrecisionParam } from 'features/parameters/types/parameterSchemas'; +import { memo, useCallback } from 'react'; + +const selector = createSelector( + stateSelector, + ({ generation }) => { + const { vaePrecision } = generation; + return { vaePrecision }; + }, + defaultSelectorOptions +); + +const DATA = ['fp16', 'fp32']; + +const ParamVAEModelSelect = () => { + const dispatch = useAppDispatch(); + const { vaePrecision } = useAppSelector(selector); + + const handleChange = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch(vaePrecisionChanged(v as PrecisionParam)); + }, + [dispatch] + ); + + return 
( + + ); +}; + +export default memo(ParamVAEModelSelect); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 971558335b..a2ddb569dd 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -11,6 +11,7 @@ import { MainModelParam, NegativePromptParam, PositivePromptParam, + PrecisionParam, SchedulerParam, SeedParam, StepsParam, @@ -51,6 +52,7 @@ export interface GenerationState { verticalSymmetrySteps: number; model: MainModelField | null; vae: VaeModelParam | null; + vaePrecision: PrecisionParam; seamlessXAxis: boolean; seamlessYAxis: boolean; clipSkip: number; @@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = { verticalSymmetrySteps: 0, model: null, vae: null, + vaePrecision: 'fp32', seamlessXAxis: false, seamlessYAxis: false, clipSkip: 0, @@ -241,6 +244,9 @@ export const generationSlice = createSlice({ // null is a valid VAE! 
state.vae = action.payload; }, + vaePrecisionChanged: (state, action: PayloadAction) => { + state.vaePrecision = action.payload; + }, setClipSkip: (state, action: PayloadAction) => { state.clipSkip = action.payload; }, @@ -327,6 +333,7 @@ export const { shouldUseCpuNoiseChanged, setShouldShowAdvancedOptions, setAspectRatio, + vaePrecisionChanged, } = generationSlice.actions; export default generationSlice.reducer; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index cea5fb9987..64f4665c5f 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -42,6 +42,42 @@ export const isValidNegativePrompt = ( val: unknown ): val is NegativePromptParam => zNegativePrompt.safeParse(val).success; +/** + * Zod schema for SDXL positive style prompt parameter + */ +export const zPositiveStylePromptSDXL = z.string(); +/** + * Type alias for SDXL positive style prompt parameter, inferred from its zod schema + */ +export type PositiveStylePromptSDXLParam = z.infer< + typeof zPositiveStylePromptSDXL +>; +/** + * Validates/type-guards a value as a SDXL positive style prompt parameter + */ +export const isValidSDXLPositiveStylePrompt = ( + val: unknown +): val is PositiveStylePromptSDXLParam => + zPositiveStylePromptSDXL.safeParse(val).success; + +/** + * Zod schema for SDXL negative style prompt parameter + */ +export const zNegativeStylePromptSDXL = z.string(); +/** + * Type alias for SDXL negative style prompt parameter, inferred from its zod schema + */ +export type NegativeStylePromptSDXLParam = z.infer< + typeof zNegativeStylePromptSDXL +>; +/** + * Validates/type-guards a value as a SDXL negative style prompt parameter + */ +export const isValidSDXLNegativeStylePrompt = ( + val: unknown +): val is NegativeStylePromptSDXLParam => + 
zNegativeStylePromptSDXL.safeParse(val).success; + /** * Zod schema for steps parameter */ @@ -260,6 +296,20 @@ export type StrengthParam = z.infer; export const isValidStrength = (val: unknown): val is StrengthParam => zStrength.safeParse(val).success; +/** + * Zod schema for a precision parameter + */ +export const zPrecision = z.enum(['fp16', 'fp32']); +/** + * Type alias for precision parameter, inferred from its zod schema + */ +export type PrecisionParam = z.infer; +/** + * Validates/type-guards a value as a precision parameter + */ +export const isValidPrecision = (val: unknown): val is PrecisionParam => + zPrecision.safeParse(val).success; + // /** // * Zod schema for BaseModelType // */ diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLImg2ImgDenoisingStrength.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLImg2ImgDenoisingStrength.tsx new file mode 100644 index 0000000000..52d7567339 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLImg2ImgDenoisingStrength.tsx @@ -0,0 +1,53 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice'; + +const selector = createSelector( + [stateSelector], + ({ sdxl }) => { + const { sdxlImg2ImgDenoisingStrength } = sdxl; + + return { + sdxlImg2ImgDenoisingStrength, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLImg2ImgDenoisingStrength = () => { + const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector); + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const handleChange = useCallback( + (v: number) => 
dispatch(setSDXLImg2ImgDenoisingStrength(v)), + [dispatch] + ); + + const handleReset = useCallback(() => { + dispatch(setSDXLImg2ImgDenoisingStrength(0.7)); + }, [dispatch]); + + return ( + + ); +}; + +export default memo(ParamSDXLImg2ImgDenoisingStrength); diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLNegativeStyleConditioning.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLNegativeStyleConditioning.tsx new file mode 100644 index 0000000000..e2d5c5b8e4 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLNegativeStyleConditioning.tsx @@ -0,0 +1,149 @@ +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; + +import { createSelector } from '@reduxjs/toolkit'; +import { clampSymmetrySteps } from 'features/parameters/store/generationSlice'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +import { userInvoked } from 'app/store/actions'; +import IAITextarea from 'common/components/IAITextarea'; +import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { isEqual } from 'lodash-es'; +import { flushSync } from 'react-dom'; +import { setNegativeStylePromptSDXL } from '../store/sdxlSlice'; + +const promptInputSelector = createSelector( + [stateSelector, activeTabNameSelector], + ({ sdxl }, activeTabName) => { + return { + prompt: sdxl.negativeStylePrompt, + activeTabName, + }; + }, + { + memoizeOptions: { + resultEqualityCheck: isEqual, + }, + } +); + +/** + * Prompt input text area. 
+ */ +const ParamSDXLNegativeStyleConditioning = () => { + const dispatch = useAppDispatch(); + const { prompt, activeTabName } = useAppSelector(promptInputSelector); + const isReady = useIsReadyToInvoke(); + const promptRef = useRef(null); + const { isOpen, onClose, onOpen } = useDisclosure(); + + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setNegativeStylePromptSDXL(e.target.value)); + }, + [dispatch] + ); + + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = prompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += prompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setNegativeStylePromptSDXL(newPrompt)); + }); + + // set the caret position to just after the TI trigger + promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, prompt] + ); + + const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled; + + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter' && e.shiftKey === false && isReady) { + e.preventDefault(); + dispatch(clampSymmetrySteps()); + dispatch(userInvoked(activeTabName)); + } + if (isEmbeddingEnabled && e.key === '<') { + onOpen(); + } + }, + [isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled] + ); + + // const handleSelect = (e: MouseEvent) => { + // const target = e.target as HTMLTextAreaElement; + // setCaret({ start: target.selectionStart, end: target.selectionEnd }); + // }; + + return ( + + + + + + + {!isOpen && isEmbeddingEnabled && ( + + 
+ + )} + + ); +}; + +export default ParamSDXLNegativeStyleConditioning; diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLPositiveStyleConditioning.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLPositiveStyleConditioning.tsx new file mode 100644 index 0000000000..8512aedd68 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLPositiveStyleConditioning.tsx @@ -0,0 +1,148 @@ +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; + +import { createSelector } from '@reduxjs/toolkit'; +import { clampSymmetrySteps } from 'features/parameters/store/generationSlice'; +import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; + +import { userInvoked } from 'app/store/actions'; +import IAITextarea from 'common/components/IAITextarea'; +import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { isEqual } from 'lodash-es'; +import { flushSync } from 'react-dom'; +import { setPositiveStylePromptSDXL } from '../store/sdxlSlice'; + +const promptInputSelector = createSelector( + [stateSelector, activeTabNameSelector], + ({ sdxl }, activeTabName) => { + return { + prompt: sdxl.positiveStylePrompt, + activeTabName, + }; + }, + { + memoizeOptions: { + resultEqualityCheck: isEqual, + }, + } +); + +/** + * Prompt input text area. 
+ */ +const ParamSDXLPositiveStyleConditioning = () => { + const dispatch = useAppDispatch(); + const { prompt, activeTabName } = useAppSelector(promptInputSelector); + const isReady = useIsReadyToInvoke(); + const promptRef = useRef(null); + const { isOpen, onClose, onOpen } = useDisclosure(); + + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setPositiveStylePromptSDXL(e.target.value)); + }, + [dispatch] + ); + + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = prompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += prompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setPositiveStylePromptSDXL(newPrompt)); + }); + + // set the caret position to just after the TI trigger + promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, prompt] + ); + + const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled; + + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter' && e.shiftKey === false && isReady) { + e.preventDefault(); + dispatch(clampSymmetrySteps()); + dispatch(userInvoked(activeTabName)); + } + if (isEmbeddingEnabled && e.key === '<') { + onOpen(); + } + }, + [isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled] + ); + + // const handleSelect = (e: MouseEvent) => { + // const target = e.target as HTMLTextAreaElement; + // setCaret({ start: target.selectionStart, end: target.selectionEnd }); + // }; + + return ( + + + + + + + {!isOpen && isEmbeddingEnabled && ( + + 
+ + )} + + ); +}; + +export default ParamSDXLPositiveStyleConditioning; diff --git a/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx new file mode 100644 index 0000000000..37e1718dc6 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/ParamSDXLRefinerCollapse.tsx @@ -0,0 +1,48 @@ +import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore'; +import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale'; +import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect'; +import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler'; +import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart'; +import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps'; +import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseSDXLRefiner } = state.sdxl; + const { shouldUseSliders } = state.ui; + return { + activeLabel: shouldUseSDXLRefiner ? 
'Enabled' : undefined, + shouldUseSliders, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerCollapse = () => { + const { activeLabel, shouldUseSliders } = useAppSelector(selector); + + return ( + + + + + + + + + + + + + + ); +}; + +export default ParamSDXLRefinerCollapse; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabCoreParameters.tsx new file mode 100644 index 0000000000..4d7c919655 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabCoreParameters.tsx @@ -0,0 +1,78 @@ +import { Box, Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAICollapse from 'common/components/IAICollapse'; +import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; +import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; +import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler'; +import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize'; +import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; +import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; +import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; +import { uiSelector } from 'features/ui/store/uiSelectors'; +import { memo } from 'react'; +import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength'; + +const selector = createSelector( + [uiSelector, generationSelector], + (ui, generation) => { + const { 
shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; + + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; + }, + defaultSelectorOptions +); + +const SDXLImageToImageTabCoreParameters = () => { + const { shouldUseSliders, activeLabel } = useAppSelector(selector); + + return ( + + + {shouldUseSliders ? ( + <> + + + + + + + + + + ) : ( + <> + + + + + + + + + + + + )} + + + + + ); +}; + +export default memo(SDXLImageToImageTabCoreParameters); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx new file mode 100644 index 0000000000..778116eefe --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLImageToImageTabParameters.tsx @@ -0,0 +1,28 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning'; +import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning'; +import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; +import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters'; + +const SDXLImageToImageTabParameters = () => { + return ( + <> + + + + + + + + + + + ); +}; + +export default 
SDXLImageToImageTabParameters; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx new file mode 100644 index 0000000000..9c9c4b2f89 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerAestheticScore.tsx @@ -0,0 +1,60 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice'; +import { memo, useCallback } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +const selector = createSelector( + [stateSelector], + ({ sdxl, hotkeys }) => { + const { refinerAestheticScore } = sdxl; + const { shift } = hotkeys; + + return { + refinerAestheticScore, + shift, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerAestheticScore = () => { + const { refinerAestheticScore, shift } = useAppSelector(selector); + + const isRefinerAvailable = useIsRefinerAvailable(); + + const dispatch = useAppDispatch(); + + const handleChange = useCallback( + (v: number) => dispatch(setRefinerAestheticScore(v)), + [dispatch] + ); + + const handleReset = useCallback( + () => dispatch(setRefinerAestheticScore(6)), + [dispatch] + ); + + return ( + + ); +}; + +export default memo(ParamSDXLRefinerAestheticScore); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx new file mode 100644 index 0000000000..dd678ac0f7 --- /dev/null +++ 
b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerCFGScale.tsx @@ -0,0 +1,75 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAINumberInput from 'common/components/IAINumberInput'; +import IAISlider from 'common/components/IAISlider'; +import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +const selector = createSelector( + [stateSelector], + ({ sdxl, ui, hotkeys }) => { + const { refinerCFGScale } = sdxl; + const { shouldUseSliders } = ui; + const { shift } = hotkeys; + + return { + refinerCFGScale, + shouldUseSliders, + shift, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerCFGScale = () => { + const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector); + const isRefinerAvailable = useIsRefinerAvailable(); + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const handleChange = useCallback( + (v: number) => dispatch(setRefinerCFGScale(v)), + [dispatch] + ); + + const handleReset = useCallback( + () => dispatch(setRefinerCFGScale(7)), + [dispatch] + ); + + return shouldUseSliders ? 
( + + ) : ( + + ); +}; + +export default memo(ParamSDXLRefinerCFGScale); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx new file mode 100644 index 0000000000..cae40bbff3 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -0,0 +1,111 @@ +import { Box, Flex } from '@chakra-ui/react'; +import { SelectItem } from '@mantine/core'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; +import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; +import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; +import { forEach } from 'lodash-es'; +import { memo, useCallback, useMemo } from 'react'; +import { REFINER_BASE_MODELS } from 'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; + +const selector = createSelector( + stateSelector, + (state) => ({ model: state.sdxl.refinerModel }), + defaultSelectorOptions +); + +const ParamSDXLRefinerModelSelect = () => { + const dispatch = useAppDispatch(); + const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; + + const { model } = useAppSelector(selector); + + const { data: refinerModels, isLoading } = + useGetMainModelsQuery(REFINER_BASE_MODELS); + + 
const data = useMemo(() => { + if (!refinerModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(refinerModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.model_name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [refinerModels]); + + // grab the full model entity from the RTK Query cache + // TODO: maybe we should just store the full model entity in state? + const selectedModel = useMemo( + () => + refinerModels?.entities[ + `${model?.base_model}/main/${model?.model_name}` + ] ?? null, + [refinerModels?.entities, model] + ); + + const handleChangeModel = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + const newModel = modelIdToMainModelParam(v); + + if (!newModel) { + return; + } + + dispatch(refinerModelChanged(newModel)); + }, + [dispatch] + ); + + return isLoading ? ( + + ) : ( + + 0 ? 'Select a model' : 'No models available'} + data={data} + error={data.length === 0} + disabled={data.length === 0} + onChange={handleChangeModel} + w="100%" + /> + {isSyncModelEnabled && ( + + + + )} + + ); +}; + +export default memo(ParamSDXLRefinerModelSelect); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx new file mode 100644 index 0000000000..e14eb0b5f8 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx @@ -0,0 +1,65 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { + SCHEDULER_LABEL_MAP, + SchedulerParam, +} from 
'features/parameters/types/parameterSchemas'; +import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice'; +import { map } from 'lodash-es'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +const selector = createSelector( + stateSelector, + ({ ui, sdxl }) => { + const { refinerScheduler } = sdxl; + const { favoriteSchedulers: enabledSchedulers } = ui; + + const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ + value: name, + label: label, + group: enabledSchedulers.includes(name as SchedulerParam) + ? 'Favorites' + : undefined, + })).sort((a, b) => a.label.localeCompare(b.label)); + + return { + refinerScheduler, + data, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerScheduler = () => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const { refinerScheduler, data } = useAppSelector(selector); + const isRefinerAvailable = useIsRefinerAvailable(); + const handleChange = useCallback( + (v: string | null) => { + if (!v) { + return; + } + dispatch(setRefinerScheduler(v as SchedulerParam)); + }, + [dispatch] + ); + + return ( + + ); +}; + +export default memo(ParamSDXLRefinerScheduler); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx new file mode 100644 index 0000000000..a98259203c --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerStart.tsx @@ -0,0 +1,53 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { setRefinerStart } from 
'features/sdxl/store/sdxlSlice'; +import { memo, useCallback } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +const selector = createSelector( + [stateSelector], + ({ sdxl }) => { + const { refinerStart } = sdxl; + return { + refinerStart, + }; + }, + defaultSelectorOptions +); + +const ParamSDXLRefinerStart = () => { + const { refinerStart } = useAppSelector(selector); + const dispatch = useAppDispatch(); + const isRefinerAvailable = useIsRefinerAvailable(); + const handleChange = useCallback( + (v: number) => dispatch(setRefinerStart(v)), + [dispatch] + ); + + const handleReset = useCallback( + () => dispatch(setRefinerStart(0.7)), + [dispatch] + ); + + return ( + + ); +}; + +export default memo(ParamSDXLRefinerStart); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx new file mode 100644 index 0000000000..456cbb5d3a --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerSteps.tsx @@ -0,0 +1,72 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAINumberInput from 'common/components/IAINumberInput'; +import IAISlider from 'common/components/IAISlider'; +import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +const selector = createSelector( + [stateSelector], + ({ sdxl, ui }) => { + const { refinerSteps } = sdxl; + const { shouldUseSliders } = ui; + + return { + refinerSteps, + shouldUseSliders, + }; + }, + defaultSelectorOptions +); 
+ +const ParamSDXLRefinerSteps = () => { + const { refinerSteps, shouldUseSliders } = useAppSelector(selector); + const isRefinerAvailable = useIsRefinerAvailable(); + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const handleChange = useCallback( + (v: number) => { + dispatch(setRefinerSteps(v)); + }, + [dispatch] + ); + const handleReset = useCallback(() => { + dispatch(setRefinerSteps(20)); + }, [dispatch]); + + return shouldUseSliders ? ( + + ) : ( + + ); +}; + +export default memo(ParamSDXLRefinerSteps); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx new file mode 100644 index 0000000000..1649f95e9a --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamUseSDXLRefiner.tsx @@ -0,0 +1,28 @@ +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice'; +import { ChangeEvent } from 'react'; +import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable'; + +export default function ParamUseSDXLRefiner() { + const shouldUseSDXLRefiner = useAppSelector( + (state: RootState) => state.sdxl.shouldUseSDXLRefiner + ); + const isRefinerAvailable = useIsRefinerAvailable(); + + const dispatch = useAppDispatch(); + + const handleUseSDXLRefinerChange = (e: ChangeEvent) => { + dispatch(setShouldUseSDXLRefiner(e.target.checked)); + }; + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx new file mode 100644 index 0000000000..2175fcc9e3 --- /dev/null +++ 
b/invokeai/frontend/web/src/features/sdxl/components/SDXLTextToImageTabParameters.tsx @@ -0,0 +1,27 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters'; +import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning'; +import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning'; +import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; + +const SDXLTextToImageTabParameters = () => { + return ( + <> + + + + + + + + + + + ); +}; + +export default SDXLTextToImageTabParameters; diff --git a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts new file mode 100644 index 0000000000..16bb806029 --- /dev/null +++ b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts @@ -0,0 +1,89 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { + MainModelParam, + NegativeStylePromptSDXLParam, + PositiveStylePromptSDXLParam, + SchedulerParam, +} from 'features/parameters/types/parameterSchemas'; +import { MainModelField } from 'services/api/types'; + +type SDXLInitialState = { + positiveStylePrompt: PositiveStylePromptSDXLParam; + negativeStylePrompt: NegativeStylePromptSDXLParam; + shouldUseSDXLRefiner: boolean; + sdxlImg2ImgDenoisingStrength: number; + refinerModel: MainModelField | null; + refinerSteps: number; + refinerCFGScale: number; + 
refinerScheduler: SchedulerParam; + refinerAestheticScore: number; + refinerStart: number; +}; + +const sdxlInitialState: SDXLInitialState = { + positiveStylePrompt: '', + negativeStylePrompt: '', + shouldUseSDXLRefiner: false, + sdxlImg2ImgDenoisingStrength: 0.7, + refinerModel: null, + refinerSteps: 20, + refinerCFGScale: 7.5, + refinerScheduler: 'euler', + refinerAestheticScore: 6, + refinerStart: 0.7, +}; + +const sdxlSlice = createSlice({ + name: 'sdxl', + initialState: sdxlInitialState, + reducers: { + setPositiveStylePromptSDXL: (state, action: PayloadAction) => { + state.positiveStylePrompt = action.payload; + }, + setNegativeStylePromptSDXL: (state, action: PayloadAction) => { + state.negativeStylePrompt = action.payload; + }, + setShouldUseSDXLRefiner: (state, action: PayloadAction) => { + state.shouldUseSDXLRefiner = action.payload; + }, + setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction) => { + state.sdxlImg2ImgDenoisingStrength = action.payload; + }, + refinerModelChanged: ( + state, + action: PayloadAction + ) => { + state.refinerModel = action.payload; + }, + setRefinerSteps: (state, action: PayloadAction) => { + state.refinerSteps = action.payload; + }, + setRefinerCFGScale: (state, action: PayloadAction) => { + state.refinerCFGScale = action.payload; + }, + setRefinerScheduler: (state, action: PayloadAction) => { + state.refinerScheduler = action.payload; + }, + setRefinerAestheticScore: (state, action: PayloadAction) => { + state.refinerAestheticScore = action.payload; + }, + setRefinerStart: (state, action: PayloadAction) => { + state.refinerStart = action.payload; + }, + }, +}); + +export const { + setPositiveStylePromptSDXL, + setNegativeStylePromptSDXL, + setShouldUseSDXLRefiner, + setSDXLImg2ImgDenoisingStrength, + refinerModelChanged, + setRefinerSteps, + setRefinerCFGScale, + setRefinerScheduler, + setRefinerAestheticScore, + setRefinerStart, +} = sdxlSlice.actions; + +export default sdxlSlice.reducer; diff --git 
a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index 2deccfa46d..49102a5bbd 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -26,6 +26,8 @@ import { setShouldConfirmOnDelete, shouldAntialiasProgressImageChanged, shouldLogToConsoleChanged, + shouldUseNSFWCheckerChanged, + shouldUseWatermarkerChanged, } from 'features/system/store/systemSlice'; import { setShouldShowProgressInViewer, @@ -42,6 +44,7 @@ import { } from 'react'; import { useTranslation } from 'react-i18next'; import { LogLevelName } from 'roarr'; +import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo'; import SettingSwitch from './SettingSwitch'; import SettingsClearIntermediates from './SettingsClearIntermediates'; import SettingsSchedulers from './SettingsSchedulers'; @@ -57,6 +60,8 @@ const selector = createSelector( shouldLogToConsole, shouldAntialiasProgressImage, isNodesEnabled, + shouldUseNSFWChecker, + shouldUseWatermarker, } = system; const { @@ -78,6 +83,8 @@ const selector = createSelector( shouldAntialiasProgressImage, shouldShowAdvancedOptions, isNodesEnabled, + shouldUseNSFWChecker, + shouldUseWatermarker, }; }, { @@ -120,6 +127,16 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => { } }, [shouldShowDeveloperSettings, dispatch]); + const { isNSFWCheckerAvailable, isWatermarkerAvailable } = + useGetAppConfigQuery(undefined, { + selectFromResult: ({ data }) => ({ + isNSFWCheckerAvailable: + data?.nsfw_methods.includes('nsfw_checker') ?? false, + isWatermarkerAvailable: + data?.watermarking_methods.includes('invisible_watermark') ?? 
false, + }), + }); + const { isOpen: isSettingsModalOpen, onOpen: onSettingsModalOpen, @@ -143,6 +160,8 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => { shouldAntialiasProgressImage, shouldShowAdvancedOptions, isNodesEnabled, + shouldUseNSFWChecker, + shouldUseWatermarker, } = useAppSelector(selector); const handleClickResetWebUI = useCallback(() => { @@ -221,6 +240,22 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => { {t('settings.generation')} + ) => + dispatch(shouldUseNSFWCheckerChanged(e.target.checked)) + } + /> + ) => + dispatch(shouldUseWatermarkerChanged(e.target.checked)) + } + /> diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 629a4f0139..7189058632 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,5 +1,5 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { InvokeLogLevel } from 'app/logging/logger'; import { userInvoked } from 'app/store/actions'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; @@ -16,13 +16,16 @@ import { appSocketGraphExecutionStateComplete, appSocketInvocationComplete, appSocketInvocationError, + appSocketInvocationRetrievalError, appSocketInvocationStarted, + appSocketSessionRetrievalError, appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from '../util/makeToast'; import { LANGUAGES } from './constants'; +import { startCase } from 'lodash-es'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -84,6 +87,8 @@ export interface SystemState { language: keyof typeof LANGUAGES; isUploading: boolean; isNodesEnabled: 
boolean; + shouldUseNSFWChecker: boolean; + shouldUseWatermarker: boolean; } export const initialSystemState: SystemState = { @@ -116,6 +121,8 @@ export const initialSystemState: SystemState = { language: 'en', isUploading: false, isNodesEnabled: false, + shouldUseNSFWChecker: true, + shouldUseWatermarker: true, }; export const systemSlice = createSlice({ @@ -191,6 +198,12 @@ export const systemSlice = createSlice({ setIsNodesEnabled(state, action: PayloadAction) { state.isNodesEnabled = action.payload; }, + shouldUseNSFWCheckerChanged(state, action: PayloadAction) { + state.shouldUseNSFWChecker = action.payload; + }, + shouldUseWatermarkerChanged(state, action: PayloadAction) { + state.shouldUseWatermarker = action.payload; + }, }, extraReducers(builder) { /** @@ -288,25 +301,6 @@ export const systemSlice = createSlice({ } }); - /** - * Invocation Error - */ - builder.addCase(appSocketInvocationError, (state) => { - state.isProcessing = false; - state.isCancelable = true; - // state.currentIteration = 0; - // state.totalIterations = 0; - state.currentStatusHasSteps = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusError'; - state.progressImage = null; - - state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) - ); - }); - /** * Graph Execution State Complete */ @@ -362,7 +356,7 @@ export const systemSlice = createSlice({ * Session Invoked - REJECTED * Session Created - REJECTED */ - builder.addMatcher(isAnySessionRejected, (state) => { + builder.addMatcher(isAnySessionRejected, (state, action) => { state.isProcessing = false; state.isCancelable = false; state.isCancelScheduled = false; @@ -372,7 +366,35 @@ export const systemSlice = createSlice({ state.progressImage = null; state.toastQueue.push( - makeToast({ title: t('toast.serverError'), status: 'error' }) + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: + action.payload?.status === 422 ? 
'Validation Error' : undefined, + }) + ); + }); + + /** + * Any server error + */ + builder.addMatcher(isAnyServerError, (state, action) => { + state.isProcessing = false; + state.isCancelable = true; + // state.currentIteration = 0; + // state.totalIterations = 0; + state.currentStatusHasSteps = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusError'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ + title: t('toast.serverError'), + status: 'error', + description: startCase(action.payload.data.error_type), + }) ); }); }, @@ -397,6 +419,14 @@ export const { languageChanged, progressImageSet, setIsNodesEnabled, + shouldUseNSFWCheckerChanged, + shouldUseWatermarkerChanged, } = systemSlice.actions; export default systemSlice.reducer; + +const isAnyServerError = isAnyOf( + appSocketInvocationError, + appSocketSessionRetrievalError, + appSocketInvocationRetrievalError +); diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 6c683470e7..26d06adfb3 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import { configSelector } from 'features/system/store/configSelectors'; -import { InvokeTabName } from 'features/ui/store/tabMap'; +import { InvokeTabName, tabMap } from 'features/ui/store/tabMap'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; import { ResourceKey } from 'i18next'; import { isEqual } from 'lodash-es'; @@ -172,13 +172,22 @@ const InvokeTabs = () => { const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } = 
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app'); + const handleTabChange = useCallback( + (index: number) => { + const activeTabName = tabMap[index]; + if (!activeTabName) { + return; + } + dispatch(setActiveTab(activeTabName)); + }, + [dispatch] + ); + return ( { - dispatch(setActiveTab(index)); - }} + onChange={handleTabChange} sx={{ flexGrow: 1, gap: 4, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTab.tsx index a0ec95d72d..d58630d1b9 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTab.tsx @@ -1,7 +1,9 @@ import { Box, Flex } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay'; +import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters'; import { memo, useCallback, useRef } from 'react'; import { ImperativePanelGroupHandle, @@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters'; const ImageToImageTab = () => { const dispatch = useAppDispatch(); const panelGroupRef = useRef(null); + const model = useAppSelector((state: RootState) => state.generation.model); const handleDoubleClickHandle = useCallback(() => { if (!panelGroupRef.current) { @@ -28,7 +31,11 @@ const ImageToImageTab = () => { return ( - + {model && model.base_model === 'sdxl' ? 
( + + ) : ( + + )} (''); // Get paths of models that are already installed - const { data: installedModels } = useGetMainModelsQuery(); + const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS); // Get all model paths from a given directory const { foundModels, alreadyInstalled, filteredModels } = diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 19ca10e240..4ad8fbaba6 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -1,5 +1,4 @@ import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; -import { makeToast } from 'features/system/util/makeToast'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; import IAIInput from 'common/components/IAIInput'; @@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; import { pickBy } from 'lodash-es'; import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { ALL_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery, useMergeMainModelsMutation, @@ -32,7 +33,7 @@ export default function MergeModelsPanel() { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const { data } = useGetMainModelsQuery(); + const { data } = useGetMainModelsQuery(ALL_BASE_MODELS); const [mergeModels, { isLoading }] = useMergeMainModelsMutation(); diff --git 
a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index f49294cfb0..87eb918564 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -8,10 +8,11 @@ import { import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; +import { ALL_BASE_MODELS } from 'services/api/constants'; export default function ModelManagerPanel() { const [selectedModelId, setSelectedModelId] = useState(); - const { model } = useGetMainModelsQuery(undefined, { + const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ model: selectedModelId ? 
data?.entities[selectedModelId] : undefined, }), diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index 722bd83b6e..f3d0eae495 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -11,6 +11,7 @@ import { useGetMainModelsQuery, } from 'services/api/endpoints/models'; import ModelListItem from './ModelListItem'; +import { ALL_BASE_MODELS } from 'services/api/constants'; type ModelListProps = { selectedModelId: string | undefined; @@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => { const [modelFormatFilter, setModelFormatFilter] = useState('images'); - const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, { + const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), }), }); - const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, { + const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({ filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), }), diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTab.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTab.tsx index 90141af785..8c3add3d62 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTab.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTab.tsx @@ -1,14 +1,22 @@ import { Flex } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { 
useAppSelector } from 'app/store/storeHooks'; +import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters'; import { memo } from 'react'; import ParametersPinnedWrapper from '../../ParametersPinnedWrapper'; import TextToImageTabMain from './TextToImageTabMain'; import TextToImageTabParameters from './TextToImageTabParameters'; const TextToImageTab = () => { + const model = useAppSelector((state: RootState) => state.generation.model); return ( - + {model && model.base_model === 'sdxl' ? ( + + ) : ( + + )} diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 81243aa03f..e487f08067 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -26,7 +26,7 @@ export const uiSlice = createSlice({ name: 'ui', initialState: initialUIState, reducers: { - setActiveTab: (state, action: PayloadAction) => { + setActiveTab: (state, action: PayloadAction) => { setActiveTabReducer(state, action.payload); }, setShouldPinParametersPanel: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/services/api/constants.ts b/invokeai/frontend/web/src/services/api/constants.ts new file mode 100644 index 0000000000..8bf35d0198 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/constants.ts @@ -0,0 +1,16 @@ +import { BaseModelType } from './types'; + +export const ALL_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sdxl', + 'sdxl-refiner', +]; + +export const NON_REFINER_BASE_MODELS: BaseModelType[] = [ + 'sd-1', + 'sd-2', + 'sdxl', +]; + +export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner']; diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index f76b56761c..2d3537998d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ 
b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -8,6 +8,7 @@ export const appInfoApi = api.injectEndpoints({ url: `app/version`, method: 'GET', }), + providesTags: ['AppVersion'], keepUnusedDataFor: 86400000, // 1 day }), getAppConfig: build.query({ @@ -15,6 +16,7 @@ export const appInfoApi = api.injectEndpoints({ url: `app/config`, method: 'GET', }), + providesTags: ['AppConfig'], keepUnusedDataFor: 86400000, // 1 day }), }), diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index ff82bc2802..3d0013a62c 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -144,8 +144,19 @@ const createModelEntities = ( export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - getMainModels: build.query, void>({ - query: () => ({ url: 'models/', params: { model_type: 'main' } }), + getMainModels: build.query< + EntityState, + BaseModelType[] + >({ + query: (base_models) => { + const params = { + model_type: 'main', + base_models, + }; + + const query = queryString.stringify(params, { arrayFormat: 'none' }); + return `models/?${query}`; + }, providesTags: (result, error, arg) => { const tags: ApiFullTagDescription[] = [ { type: 'MainModel', id: LIST_TAG }, @@ -187,7 +198,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), importMainModels: build.mutation< ImportMainModelResponse, @@ -200,7 +214,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), addMainModels: build.mutation({ query: ({ body }) => { @@ 
-210,7 +227,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), deleteMainModels: build.mutation< DeleteMainModelResponse, @@ -222,7 +242,10 @@ export const modelsApi = api.injectEndpoints({ method: 'DELETE', }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), convertMainModels: build.mutation< ConvertMainModelResponse, @@ -235,7 +258,10 @@ export const modelsApi = api.injectEndpoints({ params: params, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), mergeMainModels: build.mutation({ query: ({ base_model, body }) => { @@ -245,7 +271,10 @@ export const modelsApi = api.injectEndpoints({ body: body, }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), syncModels: build.mutation({ query: () => { @@ -254,7 +283,10 @@ export const modelsApi = api.injectEndpoints({ method: 'POST', }; }, - invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], + invalidatesTags: [ + { type: 'MainModel', id: LIST_TAG }, + { type: 'SDXLRefinerModel', id: LIST_TAG }, + ], }), getLoRAModels: build.query, void>({ query: () => ({ url: 'models/', params: { model_type: 'lora' } }), diff --git a/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts new file mode 100644 index 0000000000..4cb4891be4 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useIsRefinerAvailable.ts @@ -0,0 +1,12 @@ +import { REFINER_BASE_MODELS } from 
'services/api/constants'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; + +export const useIsRefinerAvailable = () => { + const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, { + selectFromResult: ({ data }) => ({ + isRefinerAvailable: data ? data.ids.length > 0 : false, + }), + }); + + return isRefinerAvailable; +}; diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 32e0400a8e..c8e87ef8c5 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -318,6 +318,21 @@ export type components = { * @description List of available infill methods */ infill_methods: (string)[]; + /** + * Upscaling Methods + * @description List of upscaling methods + */ + upscaling_methods: (components["schemas"]["Upscaler"])[]; + /** + * Nsfw Methods + * @description List of NSFW checking methods + */ + nsfw_methods: (string)[]; + /** + * Watermarking Methods + * @description List of invisible watermark methods + */ + watermarking_methods: (string)[]; }; /** * AppVersion @@ -1014,6 +1029,11 @@ export type components = { * @description The LoRAs used for inference */ loras: (components["schemas"]["LoRAMetadataField"])[]; + /** + * Vae + * @description The VAE used for decoding, if the main model's default was not used + */ + vae?: components["schemas"]["VAEModelField"]; /** * Strength * @description The strength used for latents-to-latents @@ -1025,10 +1045,45 @@ export type components = { */ init_image?: string; /** - * Vae - * @description The VAE used for decoding, if the main model's default was not used + * Positive Style Prompt + * @description The positive style prompt parameter */ - vae?: components["schemas"]["VAEModelField"]; + positive_style_prompt?: string; + /** + * Negative Style Prompt + * @description The negative style prompt parameter + */ + negative_style_prompt?: string; + /** + * Refiner 
Model + * @description The SDXL Refiner model used + */ + refiner_model?: components["schemas"]["MainModelField"]; + /** + * Refiner Cfg Scale + * @description The classifier-free guidance scale parameter used for the refiner + */ + refiner_cfg_scale?: number; + /** + * Refiner Steps + * @description The number of steps used for the refiner + */ + refiner_steps?: number; + /** + * Refiner Scheduler + * @description The scheduler used for the refiner + */ + refiner_scheduler?: string; + /** + * Refiner Aesthetic Store + * @description The aesthetic score used for the refiner + */ + refiner_aesthetic_store?: number; + /** + * Refiner Start + * @description The start value used for refiner denoising + */ + refiner_start?: number; }; /** * CvInpaintInvocation @@ -1305,7 +1360,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | 
components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | 
components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; + [key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | 
components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | 
components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; }; /** * Edges @@ -1348,7 +1403,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | 
components["schemas"]["NoiseOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; + [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; }; /** * Errors @@ -1922,22 +1977,11 @@ export type components = { * @description The image to check */ image?: components["schemas"]["ImageField"]; - /** - * Enabled - * @description Whether the NSFW checker is enabled - * @default false - */ - enabled?: boolean; /** * Metadata * @description Optional core metadata to be written to the image */ metadata?: components["schemas"]["CoreMetadata"]; - /** - * Default Enabled - * @default false - */ - DEFAULT_ENABLED?: boolean; }; /** * ImageOutput @@ -2252,22 +2296,11 @@ export type components = { * @default InvokeAI */ text?: 
string; - /** - * Enabled - * @description Whether the invisible watermark is enabled - * @default true - */ - enabled?: boolean; /** * Metadata * @description Optional core metadata to be written to the image */ metadata?: components["schemas"]["CoreMetadata"]; - /** - * Default Enabled - * @default true - */ - DEFAULT_ENABLED?: boolean; }; /** * InfillColorInvocation @@ -3362,6 +3395,46 @@ export type components = { * @description The VAE used for decoding, if the main model's default was not used */ vae?: components["schemas"]["VAEModelField"]; + /** + * Positive Style Prompt + * @description The positive style prompt parameter + */ + positive_style_prompt?: string; + /** + * Negative Style Prompt + * @description The negative style prompt parameter + */ + negative_style_prompt?: string; + /** + * Refiner Model + * @description The SDXL Refiner model used + */ + refiner_model?: components["schemas"]["MainModelField"]; + /** + * Refiner Cfg Scale + * @description The classifier-free guidance scale parameter used for the refiner + */ + refiner_cfg_scale?: number; + /** + * Refiner Steps + * @description The number of steps used for the refiner + */ + refiner_steps?: number; + /** + * Refiner Scheduler + * @description The scheduler used for the refiner + */ + refiner_scheduler?: string; + /** + * Refiner Aesthetic Store + * @description The aesthetic score used for the refiner + */ + refiner_aesthetic_store?: number; + /** + * Refiner Start + * @description The start value used for refiner denoising + */ + refiner_start?: number; }; /** * MetadataAccumulatorOutput @@ -5323,6 +5396,19 @@ export type components = { */ loras: (components["schemas"]["LoraInfo"])[]; }; + /** Upscaler */ + Upscaler: { + /** + * Upscaling Method + * @description Name of upscaling method + */ + upscaling_method: string; + /** + * Upscaling Models + * @description List of upscaling models for this method + */ + upscaling_models: (string)[]; + }; /** * VAEModelField * @description Vae model 
field @@ -5449,6 +5535,12 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion2ModelFormat * @description An enumeration. @@ -5461,12 +5553,6 @@ export type components = { * @enum {string} */ StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -5577,7 +5663,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | 
components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | 
components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | 
components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | 
components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { @@ -5614,7 +5700,7 @@ export type operations = { }; requestBody: { content: { - "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | 
components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | 
components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; + "application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | 
components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | 
components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; }; }; responses: { diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts index 6d20b9dd33..5588f25b46 100644 --- a/invokeai/frontend/web/src/services/api/thunks/session.ts +++ b/invokeai/frontend/web/src/services/api/thunks/session.ts @@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required< >; type CreateSessionThunkConfig = { - rejectValue: { arg: CreateSessionArg; error: unknown }; + rejectValue: { arg: CreateSessionArg; status: number; error: unknown }; }; /** @@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk< }); if (error) { - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } return data; @@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = { rejectValue: { arg: InvokedSessionArg; error: unknown; + status: number; }; }; @@ -78,9 +79,13 @@ export const 
sessionInvoked = createAsyncThunk< if (error) { if (isErrorWithStatus(error) && error.status === 403) { - return rejectWithValue({ arg, error: (error as any).body.detail }); + return rejectWithValue({ + arg, + status: response.status, + error: (error as any).body.detail, + }); } - return rejectWithValue({ arg, error }); + return rejectWithValue({ arg, status: response.status, error }); } }); diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 3e945691f1..7d9040321c 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -134,6 +134,12 @@ export type ESRGANInvocation = TypeReq< export type DivideInvocation = TypeReq< components['schemas']['DivideInvocation'] >; +export type ImageNSFWBlurInvocation = TypeReq< + components['schemas']['ImageNSFWBlurInvocation'] +>; +export type ImageWatermarkInvocation = TypeReq< + components['schemas']['ImageWatermarkInvocation'] +>; // ControlNet Nodes export type ControlNetInvocation = TypeReq< diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index b6316c5e95..35ebb725cb 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -4,9 +4,11 @@ import { GraphExecutionStateCompleteEvent, InvocationCompleteEvent, InvocationErrorEvent, + InvocationRetrievalErrorEvent, InvocationStartedEvent, ModelLoadCompletedEvent, ModelLoadStartedEvent, + SessionRetrievalErrorEvent, } from 'services/events/types'; // Create actions for each socket @@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{ export const appSocketModelLoadCompleted = createAction<{ data: ModelLoadCompletedEvent; }>('socket/appSocketModelLoadCompleted'); + +/** + * Socket.IO Session Retrieval Error + * + * Do not use. Only for use in middleware. 
+ */ +export const socketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/socketSessionRetrievalError'); + +/** + * App-level Session Retrieval Error + */ +export const appSocketSessionRetrievalError = createAction<{ + data: SessionRetrievalErrorEvent; +}>('socket/appSocketSessionRetrievalError'); + +/** + * Socket.IO Invocation Retrieval Error + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/socketInvocationRetrievalError'); + +/** + * App-level Invocation Retrieval Error + */ +export const appSocketInvocationRetrievalError = createAction<{ + data: InvocationRetrievalErrorEvent; +}>('socket/appSocketInvocationRetrievalError'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index ec1b55e3fe..37f5f24eac 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -87,6 +87,7 @@ export type InvocationErrorEvent = { graph_execution_state_id: string; node: BaseNode; source_node_id: string; + error_type: string; error: string; }; @@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = { graph_execution_state_id: string; }; +/** + * A `session_retrieval_error` socket.io event. + * + * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... } + */ +export type SessionRetrievalErrorEvent = { + graph_execution_state_id: string; + error_type: string; + error: string; +}; + +/** + * A `invocation_retrieval_error` socket.io event. + * + * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... 
} + */ +export type InvocationRetrievalErrorEvent = { + graph_execution_state_id: string; + node_id: string; + error_type: string; + error: string; +}; + export type ClientEmitSubscribe = { session: string; }; @@ -128,6 +152,8 @@ export type ServerToClientEvents = { ) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_load_completed: (payload: ModelLoadCompletedEvent) => void; + session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; + invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index d44a549183..9ebb7ffbff 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -11,9 +11,11 @@ import { socketGraphExecutionStateComplete, socketInvocationComplete, socketInvocationError, + socketInvocationRetrievalError, socketInvocationStarted, socketModelLoadCompleted, socketModelLoadStarted, + socketSessionRetrievalError, socketSubscribed, } from '../actions'; import { ClientToServerEvents, ServerToClientEvents } from '../types'; @@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => { }) ); }); + + /** + * Session retrieval error + */ + socket.on('session_retrieval_error', (data) => { + dispatch( + socketSessionRetrievalError({ + data, + }) + ); + }); + + /** + * Invocation retrieval error + */ + socket.on('invocation_retrieval_error', (data) => { + dispatch( + socketInvocationRetrievalError({ + data, + }) + ); + }); }; diff --git a/mkdocs.yml b/mkdocs.yml index 7d3e0e0b85..cbcaf52af6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,7 +101,7 @@ plugins: nav: - Home: 'index.md' - - Installation: + - Installation: - Overview: 'installation/index.md' - Installing with the Automated 
Installer: 'installation/010_INSTALL_AUTOMATED.md' - Installing manually: 'installation/020_INSTALL_MANUAL.md' @@ -122,14 +122,14 @@ nav: - Community Nodes: - Community Nodes: 'nodes/communityNodes.md' - Overview: 'nodes/overview.md' - - Features: + - Features: - Overview: 'features/index.md' - Concepts: 'features/CONCEPTS.md' - Configuration: 'features/CONFIGURATION.md' - ControlNet: 'features/CONTROLNET.md' - Image-to-Image: 'features/IMG2IMG.md' - Controlling Logging: 'features/LOGGING.md' - - Model Mergeing: 'features/MODEL_MERGING.md' + - Model Merging: 'features/MODEL_MERGING.md' - Nodes Editor (Experimental): 'features/NODES.md' - NSFW Checker: 'features/NSFW.md' - Postprocessing: 'features/POSTPROCESS.md' @@ -140,9 +140,9 @@ nav: - InvokeAI Web Server: 'features/WEB.md' - WebUI Hotkeys: "features/WEBUIHOTKEYS.md" - Other: 'features/OTHER.md' - - Contributing: + - Contributing: - How to Contribute: 'contributing/CONTRIBUTING.md' - - Development: + - Development: - Overview: 'contributing/contribution_guides/development.md' - InvokeAI Architecture: 'contributing/ARCHITECTURE.md' - Frontend Documentation: 'contributing/contribution_guides/development_guides/contributingToFrontend.md' @@ -161,5 +161,3 @@ nav: - Other: - Contributors: 'other/CONTRIBUTORS.md' - CompViz-README: 'other/README-CompViz.md' - - diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index cd995141ab..e2a4c8b343 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -462,16 +462,16 @@ def test_graph_subgraph_t2i(): n4 = ShowImageInvocation(id = "4") g.add_node(n4) - g.add_edge(create_edge("1.7","image","4","image")) + g.add_edge(create_edge("1.8","image","4","image")) # Validate dg = g.nx_graph_flat() - assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4']) + assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '2', '3', '4']) 
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges] expected_edges.extend([ ('2','1.width'), ('3','1.height'), - ('1.7','4') + ('1.8','4') ]) print(expected_edges) print(list(dg.edges)) diff --git a/tests/test_path.py b/tests/test_path.py index 6076c6554f..52936142c7 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -9,7 +9,7 @@ from PIL import Image import invokeai.frontend.web.dist as frontend import invokeai.configs as configs -import invokeai.assets.web as assets_web +import invokeai.app.assets.images as image_assets class ConfigsTestCase(unittest.TestCase): """Test the configuration related imports and objects""" @@ -35,7 +35,7 @@ class ConfigsTestCase(unittest.TestCase): def test_caution_img(self): """Verify the caution image""" - caution_img = Image.open(osp.join(assets_web.__path__[0], "caution.png")) + caution_img = Image.open(osp.join(image_assets.__path__[0], "caution.png")) assert caution_img.width == int(500) assert caution_img.height == int(441) assert caution_img.format == str("PNG")