fix merge conflicts
@ -108,40 +108,6 @@ Can be used with .and():
|
|||||||
Each will give you different results - try them out and see what you prefer!
|
Each will give you different results - try them out and see what you prefer!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Cross-Attention Control ('prompt2prompt')
|
|
||||||
|
|
||||||
Sometimes an image you generate is almost right, and you just want to change one
|
|
||||||
detail without affecting the rest. You could use a photo editor and inpainting
|
|
||||||
to overpaint the area, but that's a pain. Here's where `prompt2prompt` comes in
|
|
||||||
handy.
|
|
||||||
|
|
||||||
Generate an image with a given prompt, record the seed of the image, and then
|
|
||||||
use the `prompt2prompt` syntax to substitute words in the original prompt for
|
|
||||||
words in a new prompt. This works for `img2img` as well.
|
|
||||||
|
|
||||||
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because the words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
|
|
||||||
- `a cat playing with a ball in the forest`
|
|
||||||
- `a dog playing with a ball in the forest`
|
|
||||||
|
|
||||||
| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` |
|
|
||||||
| --- | --- |
|
|
||||||
| img | img |
|
|
||||||
|
|
||||||
|
|
||||||
- For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
|
|
||||||
- To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
|
|
||||||
- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to (bloc97's)[(https://github.com/bloc97/CrossAttentionControl)] `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
|
|
||||||
intuitively understand. `t_start` and `t_end` are used to control on which steps cross-attention control should run. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process.
|
|
||||||
- For example, if doing a diffusion with 10 steps for the prompt is `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will be run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
|
|
||||||
- Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest` with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt was being used instead. The final 3 steps will just run `a cat playing with a ball in the forest`.
|
|
||||||
> For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect.
|
|
||||||
|
|
||||||
Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower that with xformers enabled.
|
|
||||||
|
|
||||||
The `prompt2prompt` code is based off
|
|
||||||
[bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
|
|
||||||
|
|
||||||
### Escaping parentheses and speech marks
|
### Escaping parentheses and speech marks
|
||||||
|
|
||||||
If the model you are using has parentheses () or speech marks "" as part of its
|
If the model you are using has parentheses () or speech marks "" as part of its
|
||||||
|
@ -40,6 +40,25 @@ Follow the same steps to scan and import the missing models.
|
|||||||
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
|
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
|
||||||
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 is generally works well for GPUs without 16GB or more VRAM. Even on a 24GB card, the default works well.
|
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 is generally works well for GPUs without 16GB or more VRAM. Even on a 24GB card, the default works well.
|
||||||
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
|
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
|
||||||
|
- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behaviour in the linked FAQ entry.
|
||||||
|
|
||||||
|
## Shared GPU Memory (Windows)
|
||||||
|
|
||||||
|
!!! tip "Nvidia GPUs with driver 536.40"
|
||||||
|
|
||||||
|
This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
|
||||||
|
|
||||||
|
When the GPU doesn't have enough VRAM for a task, Windows is able to allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might no have enough VRAM.
|
||||||
|
|
||||||
|
When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
|
||||||
|
|
||||||
|
If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
|
||||||
|
|
||||||
|
Here's how to get the python path required in the linked guide:
|
||||||
|
|
||||||
|
- Run `invoke.bat`.
|
||||||
|
- Select option 2 for developer console.
|
||||||
|
- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
|
||||||
|
|
||||||
## Installer cannot find python (Windows)
|
## Installer cannot find python (Windows)
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
InvokeAI installer script
|
InvokeAI installer script
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import locale
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
@ -316,7 +317,9 @@ def upgrade_pip(venv_path: Path) -> str | None:
|
|||||||
python = str(venv_path.expanduser().resolve() / python)
|
python = str(venv_path.expanduser().resolve() / python)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode()
|
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
|
||||||
|
encoding=locale.getpreferredencoding()
|
||||||
|
)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(e)
|
print(e)
|
||||||
result = None
|
result = None
|
||||||
|
@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
from invokeai.backend.util.logging import logging
|
from invokeai.backend.util.logging import logging
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
|||||||
|
|
||||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||||
async def get_config() -> AppConfig:
|
async def get_config() -> AppConfig:
|
||||||
infill_methods = ["tile", "lama", "cv2"]
|
infill_methods = ["tile", "lama", "cv2", "color"] # TODO: add mosaic back
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infill_methods.append("patchmatch")
|
infill_methods.append("patchmatch")
|
||||||
|
|
||||||
|
@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
|
|||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
from invokeai.app.invocations.fields import (
|
||||||
|
ConditioningField,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
OutputField,
|
||||||
|
TensorField,
|
||||||
|
UIComponent,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.ti_utils import generate_ti_list
|
from invokeai.app.util.ti_utils import generate_ti_list
|
||||||
@ -14,7 +22,6 @@ from invokeai.backend.model_patcher import ModelPatcher
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
ConditioningFieldData,
|
ConditioningFieldData,
|
||||||
ExtraConditioningInfo,
|
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
from invokeai.backend.util.devices import torch_dtype
|
||||||
@ -36,7 +43,7 @@ from .model import CLIPField
|
|||||||
title="Prompt",
|
title="Prompt",
|
||||||
tags=["prompt", "compel"],
|
tags=["prompt", "compel"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.1.1",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -51,6 +58,9 @@ class CompelInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
|
mask: Optional[TensorField] = InputField(
|
||||||
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -98,27 +108,19 @@ class CompelInvocation(BaseInvocation):
|
|||||||
if context.config.get().log_tokenization:
|
if context.config.get().log_tokenization:
|
||||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
|
|
||||||
ec = ExtraConditioningInfo(
|
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
c = c.detach().to("cpu")
|
c = c.detach().to("cpu")
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||||
conditionings=[
|
|
||||||
BasicConditioningInfo(
|
|
||||||
embeds=c,
|
|
||||||
extra_conditioning=ec,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
return ConditioningOutput(
|
||||||
return ConditioningOutput.build(conditioning_name)
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
mask=self.mask,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
@ -132,7 +134,7 @@ class SDXLPromptInvocationBase:
|
|||||||
get_pooled: bool,
|
get_pooled: bool,
|
||||||
lora_prefix: str,
|
lora_prefix: str,
|
||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||||
tokenizer_model = tokenizer_info.model
|
tokenizer_model = tokenizer_info.model
|
||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
@ -159,7 +161,7 @@ class SDXLPromptInvocationBase:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
c_pooled = None
|
c_pooled = None
|
||||||
return c, c_pooled, None
|
return c, c_pooled
|
||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
@ -204,17 +206,12 @@ class SDXLPromptInvocationBase:
|
|||||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
# TODO: ask for optimizations? to not run text_encoder twice
|
# TODO: ask for optimizations? to not run text_encoder twice
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
if get_pooled:
|
if get_pooled:
|
||||||
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
|
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
|
||||||
else:
|
else:
|
||||||
c_pooled = None
|
c_pooled = None
|
||||||
|
|
||||||
ec = ExtraConditioningInfo(
|
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
del tokenizer
|
del tokenizer
|
||||||
del text_encoder
|
del text_encoder
|
||||||
del tokenizer_info
|
del tokenizer_info
|
||||||
@ -224,7 +221,7 @@ class SDXLPromptInvocationBase:
|
|||||||
if c_pooled is not None:
|
if c_pooled is not None:
|
||||||
c_pooled = c_pooled.detach().to("cpu")
|
c_pooled = c_pooled.detach().to("cpu")
|
||||||
|
|
||||||
return c, c_pooled, ec
|
return c, c_pooled
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -232,7 +229,7 @@ class SDXLPromptInvocationBase:
|
|||||||
title="SDXL Prompt",
|
title="SDXL Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.1.1",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -255,20 +252,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
target_height: int = InputField(default=1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||||
|
mask: Optional[TensorField] = InputField(
|
||||||
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
|
||||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
|
||||||
)
|
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
c2, c2_pooled = self.run_clip_compel(
|
||||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
||||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
|
||||||
)
|
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -307,17 +303,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
SDXLConditioningInfo(
|
SDXLConditioningInfo(
|
||||||
embeds=torch.cat([c1, c2], dim=-1),
|
embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
|
||||||
pooled_embeds=c2_pooled,
|
|
||||||
add_time_ids=add_time_ids,
|
|
||||||
extra_conditioning=ec1,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput(
|
||||||
|
conditioning=ConditioningField(
|
||||||
|
conditioning_name=conditioning_name,
|
||||||
|
mask=self.mask,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -345,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
# TODO: if there will appear lora for refiner - write proper prefix
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -354,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
assert c2_pooled is not None
|
assert c2_pooled is not None
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
|
||||||
SDXLConditioningInfo(
|
|
||||||
embeds=c2,
|
|
||||||
pooled_embeds=c2_pooled,
|
|
||||||
add_time_ids=add_time_ids,
|
|
||||||
extra_conditioning=ec2, # or None
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel):
|
|||||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||||
|
|
||||||
|
|
||||||
|
class TensorField(BaseModel):
|
||||||
|
"""A tensor primitive field."""
|
||||||
|
|
||||||
|
tensor_name: str = Field(description="The name of a tensor.")
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents tensor primitive field"""
|
"""A latents tensor primitive field"""
|
||||||
|
|
||||||
@ -226,7 +232,11 @@ class ConditioningField(BaseModel):
|
|||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
# endregion
|
mask: Optional[TensorField] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||||
|
"included regions should be set to True.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel[dict[str, Any]]):
|
class MetadataField(RootModel[dict[str, Any]]):
|
||||||
|
@ -1,154 +1,91 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
from abc import abstractmethod
|
||||||
|
from typing import Literal, get_args
|
||||||
|
|
||||||
import math
|
from PIL import Image
|
||||||
from typing import Literal, Optional, get_args
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageOps
|
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
|
||||||
from invokeai.app.util.misc import SEED_MAX
|
from invokeai.app.util.misc import SEED_MAX
|
||||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
|
||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.infill_methods.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
|
||||||
|
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
|
||||||
|
from invokeai.backend.image_util.infill_methods.tile import infill_tile
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
from .fields import InputField, WithBoard, WithMetadata
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
|
||||||
methods = ["tile", "solid", "lama", "cv2"]
|
def get_infill_methods():
|
||||||
|
methods = Literal["tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
methods.insert(0, "patchmatch")
|
methods = Literal["patchmatch", "tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||||
return methods
|
return methods
|
||||||
|
|
||||||
|
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = get_infill_methods()
|
||||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
|
||||||
|
|
||||||
def infill_lama(im: Image.Image) -> Image.Image:
|
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
lama = LaMA()
|
"""Base class for invocations that preprocess images for Infilling"""
|
||||||
return lama(im)
|
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The image to process")
|
||||||
|
|
||||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
@abstractmethod
|
||||||
if im.mode != "RGBA":
|
def infill(self, image: Image.Image) -> Image.Image:
|
||||||
return im
|
"""Infill the image with the specified method"""
|
||||||
|
pass
|
||||||
|
|
||||||
# Skip patchmatch if patchmatch isn't available
|
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||||
if not PatchMatch.patchmatch_available():
|
"""Process the image to have an alpha channel before being infilled"""
|
||||||
return im
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
has_alpha = True if image.mode == "RGBA" else False
|
||||||
|
return image, has_alpha
|
||||||
|
|
||||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
# Retrieve and process image to be infilled
|
||||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
input_image, has_alpha = self.load_image(context)
|
||||||
return im_patched
|
|
||||||
|
|
||||||
|
# If the input image has no alpha channel, return it
|
||||||
|
if has_alpha is False:
|
||||||
|
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||||
|
|
||||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
# Perform Infill action
|
||||||
return cv2_inpaint(im)
|
infilled_image = self.infill(input_image)
|
||||||
|
|
||||||
|
# Create ImageDTO for Infilled Image
|
||||||
|
infilled_image_dto = context.images.save(image=infilled_image)
|
||||||
|
|
||||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
# Return Infilled Image
|
||||||
_nrows, _ncols, depth = image.shape
|
return ImageOutput.build(infilled_image_dto)
|
||||||
_strides = image.strides
|
|
||||||
|
|
||||||
nrows, _m = divmod(_nrows, height)
|
|
||||||
ncols, _n = divmod(_ncols, width)
|
|
||||||
if _m != 0 or _n != 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return np.lib.stride_tricks.as_strided(
|
|
||||||
np.ravel(image),
|
|
||||||
shape=(nrows, ncols, height, width, depth),
|
|
||||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
|
||||||
writeable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
|
||||||
# Only fill if there's an alpha layer
|
|
||||||
if im.mode != "RGBA":
|
|
||||||
return im
|
|
||||||
|
|
||||||
a = np.asarray(im, dtype=np.uint8)
|
|
||||||
|
|
||||||
tile_size_tuple = (tile_size, tile_size)
|
|
||||||
|
|
||||||
# Get the image as tiles of a specified size
|
|
||||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
|
||||||
|
|
||||||
# Get the mask as tiles
|
|
||||||
tiles_mask = tiles[:, :, :, :, 3]
|
|
||||||
|
|
||||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
|
||||||
tmask_shape = tiles_mask.shape
|
|
||||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
|
||||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
|
||||||
tiles_mask = tiles_mask > 0
|
|
||||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
|
||||||
|
|
||||||
# Get RGB tiles in single array and filter by the mask
|
|
||||||
tshape = tiles.shape
|
|
||||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
|
||||||
filtered_tiles = tiles_all[tiles_mask]
|
|
||||||
|
|
||||||
if len(filtered_tiles) == 0:
|
|
||||||
return im
|
|
||||||
|
|
||||||
# Find all invalid tiles and replace with a random valid tile
|
|
||||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
|
||||||
rng = np.random.default_rng(seed=seed)
|
|
||||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
|
||||||
|
|
||||||
# Convert back to an image
|
|
||||||
tiles_all = tiles_all.reshape(tshape)
|
|
||||||
tiles_all = tiles_all.swapaxes(1, 2)
|
|
||||||
st = tiles_all.reshape(
|
|
||||||
(
|
|
||||||
math.prod(tiles_all.shape[0:2]),
|
|
||||||
math.prod(tiles_all.shape[2:4]),
|
|
||||||
tiles_all.shape[4],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
si = Image.fromarray(st, mode="RGBA")
|
|
||||||
|
|
||||||
return si
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
|
||||||
color: ColorField = InputField(
|
color: ColorField = InputField(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
description="The color to use to infill",
|
description="The color to use to infill",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def infill(self, image: Image.Image):
|
||||||
image = context.images.get_pil(self.image.image_name)
|
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||||
|
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
return infilled
|
||||||
image_dto = context.images.save(image=infilled)
|
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
||||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
|
||||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||||
seed: int = InputField(
|
seed: int = InputField(
|
||||||
default=0,
|
default=0,
|
||||||
@ -157,92 +94,74 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
description="The seed to use for tile generation (omit for random)",
|
description="The seed to use for tile generation (omit for random)",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def infill(self, image: Image.Image):
|
||||||
image = context.images.get_pil(self.image.image_name)
|
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||||
|
return output.infilled
|
||||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=infilled)
|
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
||||||
)
|
)
|
||||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
|
||||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def infill(self, image: Image.Image):
|
||||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
|
||||||
|
|
||||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
infill_image = image.copy()
|
|
||||||
width = int(image.width / self.downscale)
|
width = int(image.width / self.downscale)
|
||||||
height = int(image.height / self.downscale)
|
height = int(image.height / self.downscale)
|
||||||
infill_image = infill_image.resize(
|
|
||||||
|
infilled = image.resize(
|
||||||
(width, height),
|
(width, height),
|
||||||
resample=resample_mode,
|
resample=resample_mode,
|
||||||
)
|
)
|
||||||
|
infilled = infill_patchmatch(image)
|
||||||
if PatchMatch.patchmatch_available():
|
|
||||||
infilled = infill_patchmatch(infill_image)
|
|
||||||
else:
|
|
||||||
raise ValueError("PatchMatch is not available on this system")
|
|
||||||
|
|
||||||
infilled = infilled.resize(
|
infilled = infilled.resize(
|
||||||
(image.width, image.height),
|
(image.width, image.height),
|
||||||
resample=resample_mode,
|
resample=resample_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=infilled)
|
return infilled
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||||
"""Infills transparent areas of an image using the LaMa model"""
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
def infill(self, image: Image.Image):
|
||||||
|
lama = LaMA()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
return lama(image)
|
||||||
image = context.images.get_pil(self.image.image_name)
|
|
||||||
|
|
||||||
# Downloads the LaMa model if it doesn't already exist
|
|
||||||
download_with_progress_bar(
|
|
||||||
name="LaMa Inpainting Model",
|
|
||||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
|
||||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
infilled = infill_lama(image.copy())
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=infilled)
|
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||||
|
|
||||||
|
def infill(self, image: Image.Image):
|
||||||
|
return cv2_inpaint(image)
|
||||||
|
|
||||||
|
|
||||||
|
# @invocation(
|
||||||
|
# "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
|
||||||
|
# )
|
||||||
|
class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||||
|
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
tile_width: int = InputField(default=64, description="Width of the tile")
|
||||||
|
tile_height: int = InputField(default=64, description="Height of the tile")
|
||||||
|
min_color: ColorField = InputField(
|
||||||
|
default=ColorField(r=0, g=0, b=0, a=255),
|
||||||
|
description="The min threshold for color",
|
||||||
|
)
|
||||||
|
max_color: ColorField = InputField(
|
||||||
|
default=ColorField(r=255, g=255, b=255, a=255),
|
||||||
|
description="The max threshold for color",
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def infill(self, image: Image.Image):
|
||||||
image = context.images.get_pil(self.image.image_name)
|
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||||
|
|
||||||
infilled = infill_cv2(image.copy())
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=infilled)
|
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
@ -1,11 +1,23 @@
|
|||||||
from builtins import float
|
from builtins import float
|
||||||
from typing import List, Literal, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
invocation,
|
||||||
|
invocation_output,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
OutputField,
|
||||||
|
TensorField,
|
||||||
|
UIType,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
@ -23,13 +35,18 @@ class IPAdapterField(BaseModel):
|
|||||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
||||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
||||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||||
)
|
)
|
||||||
end_step_percent: float = Field(
|
end_step_percent: float = Field(
|
||||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||||
)
|
)
|
||||||
|
mask: Optional[TensorField] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
|
||||||
|
"regions should be set to True.",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("weight")
|
@field_validator("weight")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -52,7 +69,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
|||||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||||
|
|
||||||
|
|
||||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
|
||||||
class IPAdapterInvocation(BaseInvocation):
|
class IPAdapterInvocation(BaseInvocation):
|
||||||
"""Collects IP-Adapter info to pass to other nodes."""
|
"""Collects IP-Adapter info to pass to other nodes."""
|
||||||
|
|
||||||
@ -65,9 +82,9 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
ui_order=-1,
|
ui_order=-1,
|
||||||
ui_type=UIType.IPAdapterModel,
|
ui_type=UIType.IPAdapterModel,
|
||||||
)
|
)
|
||||||
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
|
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
|
||||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||||
default="auto",
|
default="ViT-H",
|
||||||
ui_order=2,
|
ui_order=2,
|
||||||
)
|
)
|
||||||
weight: Union[float, List[float]] = InputField(
|
weight: Union[float, List[float]] = InputField(
|
||||||
@ -79,6 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
end_step_percent: float = InputField(
|
end_step_percent: float = InputField(
|
||||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||||
)
|
)
|
||||||
|
mask: Optional[TensorField] = InputField(
|
||||||
|
default=None, description="A mask defining the region that this IP-Adapter applies to."
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("weight")
|
@field_validator("weight")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -96,14 +116,9 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||||
|
|
||||||
if self.clip_vision_model == "auto":
|
|
||||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||||
|
|
||||||
@ -117,6 +132,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
|
mask=self.mask,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
@ -9,6 +9,7 @@ import einops
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
@ -52,12 +53,20 @@ from invokeai.backend.lora import LoRAModelRaw
|
|||||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
BasicConditioningInfo,
|
||||||
|
IPAdapterConditioningInfo,
|
||||||
|
IPAdapterData,
|
||||||
|
Range,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
TextConditioningData,
|
||||||
|
TextConditioningRegions,
|
||||||
|
)
|
||||||
|
from invokeai.backend.util.mask import to_standard_float_mask
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
IPAdapterData,
|
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
T2IAdapterData,
|
T2IAdapterData,
|
||||||
image_resized_to_grid_as_tensor,
|
image_resized_to_grid_as_tensor,
|
||||||
@ -275,10 +284,10 @@ def get_scheduler(
|
|||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||||
)
|
)
|
||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
)
|
)
|
||||||
noise: Optional[LatentsField] = InputField(
|
noise: Optional[LatentsField] = InputField(
|
||||||
@ -356,33 +365,168 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
def _get_text_embeddings_and_masks(
|
||||||
|
self,
|
||||||
|
cond_list: list[ConditioningField],
|
||||||
|
context: InvocationContext,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||||
|
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||||
|
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||||
|
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||||
|
for cond in cond_list:
|
||||||
|
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||||
|
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||||
|
|
||||||
|
mask = cond.mask
|
||||||
|
if mask is not None:
|
||||||
|
mask = context.tensors.load(mask.tensor_name)
|
||||||
|
text_embeddings_masks.append(mask)
|
||||||
|
|
||||||
|
return text_embeddings, text_embeddings_masks
|
||||||
|
|
||||||
|
def _preprocess_regional_prompt_mask(
|
||||||
|
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Preprocess a regional prompt mask to match the target height and width.
|
||||||
|
If mask is None, returns a mask of all ones with the target height and width.
|
||||||
|
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
|
||||||
|
|
||||||
|
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||||
|
|
||||||
|
tf = torchvision.transforms.Resize(
|
||||||
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||||
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
|
resized_mask = tf(mask)
|
||||||
|
return resized_mask
|
||||||
|
|
||||||
|
def _concat_regional_text_embeddings(
|
||||||
|
self,
|
||||||
|
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||||
|
masks: Optional[list[Optional[torch.Tensor]]],
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||||
|
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||||
|
if masks is None:
|
||||||
|
masks = [None] * len(text_conditionings)
|
||||||
|
assert len(text_conditionings) == len(masks)
|
||||||
|
|
||||||
|
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
|
all_masks_are_none = all(mask is None for mask in masks)
|
||||||
|
|
||||||
|
text_embedding = []
|
||||||
|
pooled_embedding = None
|
||||||
|
add_time_ids = None
|
||||||
|
cur_text_embedding_len = 0
|
||||||
|
processed_masks = []
|
||||||
|
embedding_ranges = []
|
||||||
|
|
||||||
|
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||||
|
mask = masks[prompt_idx]
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||||
|
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||||
|
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||||
|
# mask, but we don't enforce this.
|
||||||
|
|
||||||
|
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||||
|
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||||
|
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||||
|
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||||
|
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||||
|
if pooled_embedding is None or mask is None:
|
||||||
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
|
if add_time_ids is None or mask is None:
|
||||||
|
add_time_ids = text_embedding_info.add_time_ids
|
||||||
|
|
||||||
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
|
if not all_masks_are_none:
|
||||||
|
embedding_ranges.append(
|
||||||
|
Range(
|
||||||
|
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
processed_masks.append(
|
||||||
|
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||||
|
|
||||||
|
text_embedding = torch.cat(text_embedding, dim=1)
|
||||||
|
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||||
|
|
||||||
|
regions = None
|
||||||
|
if not all_masks_are_none:
|
||||||
|
regions = TextConditioningRegions(
|
||||||
|
masks=torch.cat(processed_masks, dim=1),
|
||||||
|
ranges=embedding_ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
return SDXLConditioningInfo(
|
||||||
|
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
|
||||||
|
), regions
|
||||||
|
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
scheduler: Scheduler,
|
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
seed: int,
|
latent_height: int,
|
||||||
) -> ConditioningData:
|
latent_width: int,
|
||||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
) -> TextConditioningData:
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||||
|
cond_list = self.positive_conditioning
|
||||||
|
if not isinstance(cond_list, list):
|
||||||
|
cond_list = [cond_list]
|
||||||
|
uncond_list = self.negative_conditioning
|
||||||
|
if not isinstance(uncond_list, list):
|
||||||
|
uncond_list = [uncond_list]
|
||||||
|
|
||||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
cond_list, context, unet.device, unet.dtype
|
||||||
|
)
|
||||||
conditioning_data = ConditioningData(
|
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||||
unconditioned_embeddings=uc,
|
uncond_list, context, unet.device, unet.dtype
|
||||||
text_embeddings=c,
|
|
||||||
guidance_scale=self.cfg_scale,
|
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||||
scheduler,
|
text_conditionings=cond_text_embeddings,
|
||||||
# for ddim scheduler
|
masks=cond_text_embedding_masks,
|
||||||
eta=0.0, # ddim_eta
|
latent_height=latent_height,
|
||||||
# for ancestral and sde schedulers
|
latent_width=latent_width,
|
||||||
# flip all bits to have noise different from initial
|
dtype=unet.dtype,
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
)
|
||||||
|
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||||
|
text_conditionings=uncond_text_embeddings,
|
||||||
|
masks=uncond_text_embedding_masks,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
dtype=unet.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
conditioning_data = TextConditioningData(
|
||||||
|
uncond_text=uncond_text_embedding,
|
||||||
|
cond_text=cond_text_embedding,
|
||||||
|
uncond_regions=uncond_regions,
|
||||||
|
cond_regions=cond_regions,
|
||||||
|
guidance_scale=self.cfg_scale,
|
||||||
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
@ -488,8 +632,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||||
conditioning_data: ConditioningData,
|
|
||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
) -> Optional[list[IPAdapterData]]:
|
) -> Optional[list[IPAdapterData]]:
|
||||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||||
to the `conditioning_data` (in-place).
|
to the `conditioning_data` (in-place).
|
||||||
@ -505,7 +651,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
ip_adapter_data_list = []
|
ip_adapter_data_list = []
|
||||||
conditioning_data.ip_adapter_conditioning = []
|
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||||
@ -528,9 +673,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
single_ipa_images, image_encoder_model
|
single_ipa_images, image_encoder_model
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data.ip_adapter_conditioning.append(
|
mask = single_ip_adapter.mask
|
||||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
if mask is not None:
|
||||||
)
|
mask = context.tensors.load(mask.tensor_name)
|
||||||
|
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||||
|
|
||||||
ip_adapter_data_list.append(
|
ip_adapter_data_list.append(
|
||||||
IPAdapterData(
|
IPAdapterData(
|
||||||
@ -538,6 +684,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
weight=single_ip_adapter.weight,
|
weight=single_ip_adapter.weight,
|
||||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||||
end_step_percent=single_ip_adapter.end_step_percent,
|
end_step_percent=single_ip_adapter.end_step_percent,
|
||||||
|
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||||
|
mask=mask,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -627,6 +775,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
steps: int,
|
steps: int,
|
||||||
denoising_start: float,
|
denoising_start: float,
|
||||||
denoising_end: float,
|
denoising_end: float,
|
||||||
|
seed: int,
|
||||||
) -> Tuple[int, List[int], int]:
|
) -> Tuple[int, List[int], int]:
|
||||||
assert isinstance(scheduler, ConfigMixin)
|
assert isinstance(scheduler, ConfigMixin)
|
||||||
if scheduler.config.get("cpu_only", False):
|
if scheduler.config.get("cpu_only", False):
|
||||||
@ -655,7 +804,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||||
num_inference_steps = len(timesteps) // scheduler.order
|
num_inference_steps = len(timesteps) // scheduler.order
|
||||||
|
|
||||||
return num_inference_steps, timesteps, init_timestep
|
scheduler_step_kwargs = {}
|
||||||
|
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||||
|
if "generator" in scheduler_step_signature.parameters:
|
||||||
|
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||||
|
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||||
|
# reproducibility.
|
||||||
|
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||||
|
|
||||||
|
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||||
|
|
||||||
def prep_inpaint_mask(
|
def prep_inpaint_mask(
|
||||||
self, context: InvocationContext, latents: torch.Tensor
|
self, context: InvocationContext, latents: torch.Tensor
|
||||||
@ -749,7 +906,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
|
||||||
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
conditioning_data = self.get_conditioning_data(
|
||||||
|
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||||
|
)
|
||||||
|
|
||||||
controlnet_data = self.prep_control_data(
|
controlnet_data = self.prep_control_data(
|
||||||
context=context,
|
context=context,
|
||||||
@ -763,16 +924,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
ip_adapter_data = self.prep_ip_adapter_data(
|
ip_adapter_data = self.prep_ip_adapter_data(
|
||||||
context=context,
|
context=context,
|
||||||
ip_adapter=self.ip_adapter,
|
ip_adapter=self.ip_adapter,
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
dtype=unet.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||||
scheduler,
|
scheduler,
|
||||||
device=unet.device,
|
device=unet.device,
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
denoising_start=self.denoising_start,
|
denoising_start=self.denoising_start,
|
||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_latents = pipeline.latents_from_embeddings(
|
result_latents = pipeline.latents_from_embeddings(
|
||||||
@ -785,6 +949,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
masked_latents=masked_latents,
|
masked_latents=masked_latents,
|
||||||
gradient_mask=gradient_mask,
|
gradient_mask=gradient_mask,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=controlnet_data,
|
control_data=controlnet_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
@ -799,7 +964,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = context.tensors.save(tensor=result_latents)
|
name = context.tensors.save(tensor=result_latents)
|
||||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -1254,7 +1419,7 @@ class IdealSizeInvocation(BaseInvocation):
|
|||||||
return tuple((x - x % multiple_of) for x in args)
|
return tuple((x - x % multiple_of) for x in args)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||||
unet_config = context.models.get_config(**self.unet.unet.model_dump())
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
aspect = self.width / self.height
|
aspect = self.width / self.height
|
||||||
dimension: float = 512
|
dimension: float = 512
|
||||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||||
|
36
invokeai/app/invocations/mask.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
|
||||||
|
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
|
||||||
|
from invokeai.app.invocations.primitives import MaskOutput
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"rectangle_mask",
|
||||||
|
title="Create Rectangle Mask",
|
||||||
|
tags=["conditioning"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.1",
|
||||||
|
)
|
||||||
|
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||||
|
"""Create a rectangular mask."""
|
||||||
|
|
||||||
|
width: int = InputField(description="The width of the entire mask.")
|
||||||
|
height: int = InputField(description="The height of the entire mask.")
|
||||||
|
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
|
||||||
|
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
|
||||||
|
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
|
||||||
|
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
|
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||||
|
mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
|
||||||
|
True
|
||||||
|
)
|
||||||
|
|
||||||
|
mask_tensor_name = context.tensors.save(mask)
|
||||||
|
return MaskOutput(
|
||||||
|
mask=TensorField(tensor_name=mask_tensor_name),
|
||||||
|
width=self.width,
|
||||||
|
height=self.height,
|
||||||
|
)
|
@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
InputField,
|
InputField,
|
||||||
LatentsField,
|
LatentsField,
|
||||||
OutputField,
|
OutputField,
|
||||||
|
TensorField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
@ -405,9 +406,19 @@ class ColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
# region Conditioning
|
# region Conditioning
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("mask_output")
|
||||||
|
class MaskOutput(BaseInvocationOutput):
|
||||||
|
"""A torch mask tensor."""
|
||||||
|
|
||||||
|
mask: TensorField = OutputField(description="The mask.")
|
||||||
|
width: int = OutputField(description="The width of the mask in pixels.")
|
||||||
|
height: int = OutputField(description="The height of the mask in pixels.")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("conditioning_output")
|
@invocation_output("conditioning_output")
|
||||||
class ConditioningOutput(BaseInvocationOutput):
|
class ConditioningOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single conditioning tensor"""
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import locale
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@ -401,7 +402,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||||
"""
|
"""
|
||||||
assert config_path.suffix == ".yaml"
|
assert config_path.suffix == ".yaml"
|
||||||
with open(config_path) as file:
|
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||||
loaded_config_dict = yaml.safe_load(file)
|
loaded_config_dict = yaml.safe_load(file)
|
||||||
|
|
||||||
assert isinstance(loaded_config_dict, dict)
|
assert isinstance(loaded_config_dict, dict)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Model installation class."""
|
"""Model installation class."""
|
||||||
|
|
||||||
|
import locale
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import signal
|
import signal
|
||||||
@ -323,7 +324,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
|
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
|
||||||
|
|
||||||
if legacy_models_yaml_path.exists():
|
if legacy_models_yaml_path.exists():
|
||||||
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||||
|
legacy_models_yaml = yaml.safe_load(file)
|
||||||
|
|
||||||
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
||||||
yaml_version = yaml_metadata.get("version")
|
yaml_version = yaml_metadata.get("version")
|
||||||
|
@ -80,6 +80,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
ram_cache = ModelCache(
|
ram_cache = ModelCache(
|
||||||
max_cache_size=app_config.ram,
|
max_cache_size=app_config.ram,
|
||||||
max_vram_cache_size=app_config.vram,
|
max_vram_cache_size=app_config.vram,
|
||||||
|
lazy_offloading=app_config.lazy_offload,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
execution_device=execution_device,
|
execution_device=execution_device,
|
||||||
)
|
)
|
||||||
|
@ -86,6 +86,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._poll_now()
|
self._poll_now()
|
||||||
elif event_name == "batch_enqueued":
|
elif event_name == "batch_enqueued":
|
||||||
self._poll_now()
|
self._poll_now()
|
||||||
|
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||||
|
"completed",
|
||||||
|
"failed",
|
||||||
|
"canceled",
|
||||||
|
]:
|
||||||
|
self._poll_now()
|
||||||
|
|
||||||
def resume(self) -> SessionProcessorStatus:
|
def resume(self) -> SessionProcessorStatus:
|
||||||
if not self._resume_event.is_set():
|
if not self._resume_event.is_set():
|
||||||
|
@ -249,6 +249,18 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
"""
|
"""
|
||||||
return self._services.images.get_dto(image_name)
|
return self._services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||||
|
"""Gets the internal path to an image or thumbnail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_name: The name of the image to get the path of.
|
||||||
|
thumbnail: Get the path of the thumbnail instead of the full image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The local path of the image or thumbnail.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_path(image_name, thumbnail)
|
||||||
|
|
||||||
|
|
||||||
class TensorsInterface(InvocationContextInterface):
|
class TensorsInterface(InvocationContextInterface):
|
||||||
def save(self, tensor: Tensor) -> str:
|
def save(self, tensor: Tensor) -> str:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Initialization file for invokeai.backend.image_util methods.
|
Initialization file for invokeai.backend.image_util methods.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .patchmatch import PatchMatch # noqa: F401
|
from .infill_methods.patchmatch import PatchMatch # noqa: F401
|
||||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||||
from .seamless import configure_model_padding # noqa: F401
|
from .seamless import configure_model_padding # noqa: F401
|
||||||
from .util import InitImageResizer, make_grid # noqa: F401
|
from .util import InitImageResizer, make_grid # noqa: F401
|
||||||
|
@ -7,6 +7,7 @@ from PIL import Image
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
|
|
||||||
@ -30,6 +31,14 @@ class LaMA:
|
|||||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||||
device = choose_torch_device()
|
device = choose_torch_device()
|
||||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||||
|
|
||||||
|
if not model_location.exists():
|
||||||
|
download_with_progress_bar(
|
||||||
|
name="LaMa Inpainting Model",
|
||||||
|
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||||
|
dest_path=model_location,
|
||||||
|
)
|
||||||
|
|
||||||
model = load_jit_model(model_location, device)
|
model = load_jit_model(model_location, device)
|
||||||
|
|
||||||
image = np.asarray(input_image.convert("RGB"))
|
image = np.asarray(input_image.convert("RGB"))
|
60
invokeai/backend/image_util/infill_methods/mosaic.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def infill_mosaic(
|
||||||
|
image: Image.Image,
|
||||||
|
tile_shape: Tuple[int, int] = (64, 64),
|
||||||
|
min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
|
||||||
|
max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
|
||||||
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
image:PIL - A PIL Image
|
||||||
|
tile_shape: Tuple[int,int] - Tile width & Tile Height
|
||||||
|
min_color: Tuple[int,int,int] - RGB values for the lowest color to clip to (0-255)
|
||||||
|
max_color: Tuple[int,int,int] - RGB values for the highest color to clip to (0-255)
|
||||||
|
"""
|
||||||
|
|
||||||
|
np_image = np.array(image) # Convert image to np array
|
||||||
|
alpha = np_image[:, :, 3] # Get the mask from the alpha channel of the image
|
||||||
|
non_transparent_pixels = np_image[alpha != 0, :3] # List of non-transparent pixels
|
||||||
|
|
||||||
|
# Create color tiles to paste in the empty areas of the image
|
||||||
|
tile_width, tile_height = tile_shape
|
||||||
|
|
||||||
|
# Clip the range of colors in the image to a particular spectrum only
|
||||||
|
r_min, g_min, b_min, _ = min_color
|
||||||
|
r_max, g_max, b_max, _ = max_color
|
||||||
|
non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
|
||||||
|
non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
|
||||||
|
non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)
|
||||||
|
|
||||||
|
tiles = []
|
||||||
|
for _ in range(256):
|
||||||
|
color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
|
||||||
|
tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
|
||||||
|
tile[:, :] = color
|
||||||
|
tiles.append(tile)
|
||||||
|
|
||||||
|
# Fill the transparent area with tiles
|
||||||
|
filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
for x in range(image.width):
|
||||||
|
for y in range(image.height):
|
||||||
|
tile = tiles[np.random.randint(len(tiles))]
|
||||||
|
try:
|
||||||
|
filled_image[
|
||||||
|
y - (y % tile_height) : y - (y % tile_height) + tile_height,
|
||||||
|
x - (x % tile_width) : x - (x % tile_width) + tile_width,
|
||||||
|
] = tile
|
||||||
|
except ValueError:
|
||||||
|
# Need to handle edge cases - literally
|
||||||
|
pass
|
||||||
|
|
||||||
|
filled_image = Image.fromarray(filled_image) # Convert the filled tiles image to PIL
|
||||||
|
image = Image.composite(
|
||||||
|
image, filled_image, image.split()[-1]
|
||||||
|
) # Composite the original image on top of the filled tiles
|
||||||
|
return image
|
67
invokeai/backend/image_util/infill_methods/patchmatch.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
This module defines a singleton object, "patchmatch" that
|
||||||
|
wraps the actual patchmatch object. It respects the global
|
||||||
|
"try_patchmatch" attribute, so that patchmatch loading can
|
||||||
|
be suppressed or deferred
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMatch:
|
||||||
|
"""
|
||||||
|
Thin class wrapper around the patchmatch function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
patch_match = None
|
||||||
|
tried_load: bool = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_patch_match(cls):
|
||||||
|
if cls.tried_load:
|
||||||
|
return
|
||||||
|
if get_config().patchmatch:
|
||||||
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
|
if pm.patchmatch_available:
|
||||||
|
logger.info("Patchmatch initialized")
|
||||||
|
cls.patch_match = pm
|
||||||
|
else:
|
||||||
|
logger.info("Patchmatch not loaded (nonfatal)")
|
||||||
|
else:
|
||||||
|
logger.info("Patchmatch loading disabled")
|
||||||
|
cls.tried_load = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patchmatch_available(cls) -> bool:
|
||||||
|
cls._load_patch_match()
|
||||||
|
if not cls.patch_match:
|
||||||
|
return False
|
||||||
|
return cls.patch_match.patchmatch_available
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def inpaint(cls, image: Image.Image) -> Image.Image:
|
||||||
|
if cls.patch_match is None or not cls.patchmatch_available():
|
||||||
|
return image
|
||||||
|
|
||||||
|
np_image = np.array(image)
|
||||||
|
mask = 255 - np_image[:, :, 3]
|
||||||
|
infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
|
||||||
|
return Image.fromarray(infilled, mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def infill_patchmatch(image: Image.Image) -> Image.Image:
|
||||||
|
IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()
|
||||||
|
|
||||||
|
if not IS_PATCHMATCH_AVAILABLE:
|
||||||
|
logger.warning("PatchMatch is not available on this system")
|
||||||
|
return image
|
||||||
|
|
||||||
|
return PatchMatch.inpaint(image)
|
After Width: | Height: | Size: 45 KiB |
After Width: | Height: | Size: 2.2 KiB |
After Width: | Height: | Size: 36 KiB |
After Width: | Height: | Size: 33 KiB |
After Width: | Height: | Size: 21 KiB |
After Width: | Height: | Size: 39 KiB |
After Width: | Height: | Size: 42 KiB |
After Width: | Height: | Size: 48 KiB |
After Width: | Height: | Size: 49 KiB |
After Width: | Height: | Size: 60 KiB |
95
invokeai/backend/image_util/infill_methods/tile.ipynb
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\"\"\"Smoke test for the tile infill\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from typing import Optional\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"from invokeai.backend.image_util.infill_methods.tile import infill_tile\n",
|
||||||
|
"\n",
|
||||||
|
"images: list[tuple[str, Image.Image]] = []\n",
|
||||||
|
"\n",
|
||||||
|
"for i in sorted(Path(\"./test_images/\").glob(\"*.webp\")):\n",
|
||||||
|
" images.append((i.name, Image.open(i)))\n",
|
||||||
|
" images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))\n",
|
||||||
|
" images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))\n",
|
||||||
|
" images.append((i.name, Image.open(i).resize((512, 512))))\n",
|
||||||
|
" images.append((i.name, Image.open(i).resize((1234, 461))))\n",
|
||||||
|
"\n",
|
||||||
|
"outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []\n",
|
||||||
|
"\n",
|
||||||
|
"for name, image in images:\n",
|
||||||
|
" try:\n",
|
||||||
|
" output = infill_tile(image, seed=0, tile_size=32)\n",
|
||||||
|
" outputs.append((name, image, output.infilled, output.tile_image))\n",
|
||||||
|
" except ValueError as e:\n",
|
||||||
|
" print(f\"Skipping image {name}: {e}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Display the images in jupyter notebook\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from PIL import ImageOps\n",
|
||||||
|
"\n",
|
||||||
|
"fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))\n",
|
||||||
|
"plt.subplots_adjust(hspace=0)\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (name, original, infilled, tile_image) in enumerate(outputs):\n",
|
||||||
|
" # Add a border to each image, helps to see the edges\n",
|
||||||
|
" size = original.size\n",
|
||||||
|
" original = ImageOps.expand(original, border=5, fill=\"red\")\n",
|
||||||
|
" filled = ImageOps.expand(infilled, border=5, fill=\"red\")\n",
|
||||||
|
" if tile_image:\n",
|
||||||
|
" tile_image = ImageOps.expand(tile_image, border=5, fill=\"red\")\n",
|
||||||
|
"\n",
|
||||||
|
" axes[i, 0].imshow(original)\n",
|
||||||
|
" axes[i, 0].axis(\"off\")\n",
|
||||||
|
" axes[i, 0].set_title(f\"Original ({name} - {size})\")\n",
|
||||||
|
"\n",
|
||||||
|
" if tile_image:\n",
|
||||||
|
" axes[i, 1].imshow(tile_image)\n",
|
||||||
|
" axes[i, 1].axis(\"off\")\n",
|
||||||
|
" axes[i, 1].set_title(\"Tile Image\")\n",
|
||||||
|
" else:\n",
|
||||||
|
" axes[i, 1].axis(\"off\")\n",
|
||||||
|
" axes[i, 1].set_title(\"NO TILES GENERATED (NO TRANSPARENCY)\")\n",
|
||||||
|
"\n",
|
||||||
|
" axes[i, 2].imshow(filled)\n",
|
||||||
|
" axes[i, 2].axis(\"off\")\n",
|
||||||
|
" axes[i, 2].set_title(\"Filled\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".invokeai",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
122
invokeai/backend/image_util/infill_methods/tile.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_array: numpy array of the image.
|
||||||
|
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of numpy arrays, each representing a tile.
|
||||||
|
"""
|
||||||
|
tiles: list[np.ndarray] = []
|
||||||
|
rows, cols = img_array.shape[:2]
|
||||||
|
tile_width, tile_height = tile_size
|
||||||
|
|
||||||
|
for y in range(0, rows - tile_height + 1, tile_height):
|
||||||
|
for x in range(0, cols - tile_width + 1, tile_width):
|
||||||
|
tile = img_array[y : y + tile_height, x : x + tile_width]
|
||||||
|
# Check if the image has an alpha channel and the tile is completely opaque
|
||||||
|
if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
|
||||||
|
tiles.append(tile)
|
||||||
|
elif img_array.shape[2] == 3: # If no alpha channel, append the tile
|
||||||
|
tiles.append(tile)
|
||||||
|
|
||||||
|
if not tiles:
|
||||||
|
raise ValueError(
|
||||||
|
"Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
|
||||||
|
)
|
||||||
|
|
||||||
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
|
def create_filled_image(
|
||||||
|
img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Create an image of the same dimensions as the original, filled entirely with tiles from the pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_array: numpy array of the original image.
|
||||||
|
tile_pool: A list of numpy arrays, each representing a tile.
|
||||||
|
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array representing the filled image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rows, cols, _ = img_array.shape
|
||||||
|
tile_width, tile_height = tile_size
|
||||||
|
|
||||||
|
# Prep an empty RGB image
|
||||||
|
filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)
|
||||||
|
|
||||||
|
# Make the random tile selection reproducible
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
for y in range(0, rows, tile_height):
|
||||||
|
for x in range(0, cols, tile_width):
|
||||||
|
# Pick a random tile from the pool
|
||||||
|
tile = tile_pool[rng.integers(len(tile_pool))]
|
||||||
|
|
||||||
|
# Calculate the space available (may be less than tile size near the edges)
|
||||||
|
space_y = min(tile_height, rows - y)
|
||||||
|
space_x = min(tile_width, cols - x)
|
||||||
|
|
||||||
|
# Crop the tile if necessary to fit into the available space
|
||||||
|
cropped_tile = tile[:space_y, :space_x, :3]
|
||||||
|
|
||||||
|
# Fill the available space with the (possibly cropped) tile
|
||||||
|
filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile
|
||||||
|
|
||||||
|
return filled_img_array
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InfillTileOutput:
|
||||||
|
infilled: Image.Image
|
||||||
|
tile_image: Optional[Image.Image] = None
|
||||||
|
|
||||||
|
|
||||||
|
def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
|
||||||
|
"""Infills an image with random tiles from the image itself.
|
||||||
|
|
||||||
|
If the image is not an RGBA image, it is returned untouched.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to infill.
|
||||||
|
tile_size: The size of the tiles to use for infilling.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are not enough opaque pixels to generate any tiles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if image_to_infill.mode != "RGBA":
|
||||||
|
return InfillTileOutput(infilled=image_to_infill)
|
||||||
|
|
||||||
|
# Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
|
||||||
|
_tile_size = (tile_size, tile_size)
|
||||||
|
np_image = np.array(image_to_infill, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Create the pool of tiles that we will use to infill
|
||||||
|
tile_pool = create_tile_pool(np_image, _tile_size)
|
||||||
|
|
||||||
|
# Create an image from the tiles, same size as the original
|
||||||
|
tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)
|
||||||
|
|
||||||
|
# Paste the OG image over the tile image, effectively infilling the area
|
||||||
|
tile_image = Image.fromarray(tile_np_image, "RGB")
|
||||||
|
infilled = tile_image.copy()
|
||||||
|
infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])
|
||||||
|
|
||||||
|
# I think we want this to be "RGBA"?
|
||||||
|
infilled.convert("RGBA")
|
||||||
|
|
||||||
|
return InfillTileOutput(infilled=infilled, tile_image=tile_image)
|
@ -1,49 +0,0 @@
|
|||||||
"""
|
|
||||||
This module defines a singleton object, "patchmatch" that
|
|
||||||
wraps the actual patchmatch object. It respects the global
|
|
||||||
"try_patchmatch" attribute, so that patchmatch loading can
|
|
||||||
be suppressed or deferred
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
|
||||||
"""
|
|
||||||
Thin class wrapper around the patchmatch function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
patch_match = None
|
|
||||||
tried_load: bool = False
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_patch_match(self):
|
|
||||||
if self.tried_load:
|
|
||||||
return
|
|
||||||
if get_config().patchmatch:
|
|
||||||
from patchmatch import patch_match as pm
|
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
|
||||||
logger.info("Patchmatch initialized")
|
|
||||||
else:
|
|
||||||
logger.info("Patchmatch not loaded (nonfatal)")
|
|
||||||
self.patch_match = pm
|
|
||||||
else:
|
|
||||||
logger.info("Patchmatch loading disabled")
|
|
||||||
self.tried_load = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def patchmatch_available(self) -> bool:
|
|
||||||
self._load_patch_match()
|
|
||||||
return self.patch_match and self.patch_match.patchmatch_available
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def inpaint(self, *args, **kwargs) -> np.ndarray:
|
|
||||||
if self.patchmatch_available():
|
|
||||||
return self.patch_match.inpaint(*args, **kwargs)
|
|
@ -1,182 +0,0 @@
|
|||||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
|
||||||
# and modified as needed
|
|
||||||
|
|
||||||
# tencent-ailab comment:
|
|
||||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
|
||||||
|
|
||||||
|
|
||||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
|
||||||
# loading.
|
|
||||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
DiffusersAttnProcessor2_0.__init__(self)
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
|
||||||
ip_adapter_image_prompt_embeds parameter.
|
|
||||||
"""
|
|
||||||
return DiffusersAttnProcessor2_0.__call__(
|
|
||||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor2_0(torch.nn.Module):
|
|
||||||
r"""
|
|
||||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
|
||||||
Args:
|
|
||||||
hidden_size (`int`):
|
|
||||||
The hidden size of the attention layer.
|
|
||||||
cross_attention_dim (`int`):
|
|
||||||
The number of channels in the `encoder_hidden_states`.
|
|
||||||
scale (`float`, defaults to 1.0):
|
|
||||||
the weight scale of image prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
|
||||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
|
||||||
|
|
||||||
assert len(weights) == len(scales)
|
|
||||||
|
|
||||||
self._weights = weights
|
|
||||||
self._scales = scales
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
temb=None,
|
|
||||||
ip_adapter_image_prompt_embeds=None,
|
|
||||||
):
|
|
||||||
"""Apply IP-Adapter attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
|
||||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
|
||||||
"""
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
# scaled_dot_product_attention expects attention_mask shape to be
|
|
||||||
# (batch, heads, source_length, target_length)
|
|
||||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
|
||||||
value = attn.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn.heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
if encoder_hidden_states is not None:
|
|
||||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
|
||||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
|
||||||
assert ip_adapter_image_prompt_embeds is not None
|
|
||||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
|
||||||
|
|
||||||
for ipa_embed, ipa_weights, scale in zip(
|
|
||||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
|
||||||
):
|
|
||||||
# The batch dimensions should match.
|
|
||||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
|
||||||
# The token_len dimensions should match.
|
|
||||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
|
||||||
|
|
||||||
ip_hidden_states = ipa_embed
|
|
||||||
|
|
||||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
|
||||||
|
|
||||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
|
||||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
|
||||||
|
|
||||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
|
||||||
|
|
||||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
|
||||||
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
||||||
ip_hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + scale * ip_hidden_states
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._ram_cache = ram_cache
|
self._ram_cache = ram_cache
|
||||||
self._convert_cache = convert_cache
|
self._convert_cache = convert_cache
|
||||||
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
|
self._torch_dtype = torch_dtype(choose_torch_device())
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
|
@ -117,7 +117,7 @@ class ModelCacheBase(ABC, Generic[T]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stats(self) -> CacheStats:
|
def stats(self) -> Optional[CacheStats]:
|
||||||
"""Return collected CacheStats object."""
|
"""Return collected CacheStats object."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -270,12 +270,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
if torch.device(source_device).type == torch.device(target_device).type:
|
if torch.device(source_device).type == torch.device(target_device).type:
|
||||||
return
|
return
|
||||||
|
|
||||||
# may raise an exception here if insufficient GPU VRAM
|
|
||||||
self._check_free_vram(target_device, cache_entry.size)
|
|
||||||
|
|
||||||
start_model_to_time = time.time()
|
start_model_to_time = time.time()
|
||||||
snapshot_before = self._capture_memory_snapshot()
|
snapshot_before = self._capture_memory_snapshot()
|
||||||
|
try:
|
||||||
cache_entry.model.to(target_device)
|
cache_entry.model.to(target_device)
|
||||||
|
except Exception as e: # blow away cache entry
|
||||||
|
self._delete_cache_entry(cache_entry)
|
||||||
|
raise e
|
||||||
|
|
||||||
snapshot_after = self._capture_memory_snapshot()
|
snapshot_after = self._capture_memory_snapshot()
|
||||||
end_model_to_time = time.time()
|
end_model_to_time = time.time()
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
@ -330,11 +332,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_room(self, model_size: int) -> None:
|
def make_room(self, size: int) -> None:
|
||||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||||
# calculate how much memory this model will require
|
# calculate how much memory this model will require
|
||||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||||
bytes_needed = model_size
|
bytes_needed = size
|
||||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||||
current_size = self.cache_size()
|
current_size = self.cache_size()
|
||||||
|
|
||||||
@ -389,12 +391,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
# 1 from onnx runtime object
|
# 1 from onnx runtime object
|
||||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||||
)
|
)
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
models_cleared += 1
|
models_cleared += 1
|
||||||
del self._cache_stack[pos]
|
self._delete_cache_entry(cache_entry)
|
||||||
del self._cached_models[model_key]
|
|
||||||
del cache_entry
|
del cache_entry
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -422,16 +423,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
|
|
||||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||||
if target_device.type != "cuda":
|
self._cache_stack.remove(cache_entry.key)
|
||||||
return
|
del self._cached_models[cache_entry.key]
|
||||||
vram_device = ( # mem_get_info() needs an indexed device
|
|
||||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
|
||||||
)
|
|
||||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
|
||||||
if needed_size > free_mem:
|
|
||||||
needed_gb = round(needed_size / GIG, 2)
|
|
||||||
free_gb = round(free_mem / GIG, 2)
|
|
||||||
raise torch.cuda.OutOfMemoryError(
|
|
||||||
f"Insufficient VRAM to load model, requested {needed_gb}GB but only had {free_gb}GB free"
|
|
||||||
)
|
|
||||||
|
@ -34,7 +34,6 @@ class ModelLocker(ModelLockerBase):
|
|||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||||
self._cache_entry.lock()
|
self._cache_entry.lock()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self._cache.lazy_offloading:
|
if self._cache.lazy_offloading:
|
||||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||||
@ -51,6 +50,7 @@ class ModelLocker(ModelLockerBase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self._cache_entry.unlock()
|
self._cache_entry.unlock()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def unlock(self) -> None:
|
def unlock(self) -> None:
|
||||||
|
@ -21,10 +21,12 @@ from pydantic import Field
|
|||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
IPAdapterData,
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
TextConditioningData,
|
||||||
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||||
from invokeai.backend.util.devices import normalize_device
|
from invokeai.backend.util.devices import normalize_device
|
||||||
|
|
||||||
@ -149,16 +151,6 @@ class ControlNetData:
|
|||||||
resize_mode: str = Field(default="just_resize")
|
resize_mode: str = Field(default="just_resize")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class IPAdapterData:
|
|
||||||
ip_adapter_model: IPAdapter = Field(default=None)
|
|
||||||
# TODO: change to polymorphic so can do different weights per step (once implemented...)
|
|
||||||
weight: Union[float, List[float]] = Field(default=1.0)
|
|
||||||
# weight: float = Field(default=1.0)
|
|
||||||
begin_step_percent: float = Field(default=0.0)
|
|
||||||
end_step_percent: float = Field(default=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class T2IAdapterData:
|
class T2IAdapterData:
|
||||||
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
|
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
|
||||||
@ -295,7 +287,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self,
|
self,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
conditioning_data: ConditioningData,
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
|
conditioning_data: TextConditioningData,
|
||||||
*,
|
*,
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -308,7 +301,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
masked_latents: Optional[torch.Tensor] = None,
|
masked_latents: Optional[torch.Tensor] = None,
|
||||||
gradient_mask: Optional[bool] = False,
|
gradient_mask: Optional[bool] = False,
|
||||||
seed: Optional[int] = None,
|
seed: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if init_timestep.shape[0] == 0:
|
if init_timestep.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
@ -326,20 +319,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
|
||||||
if noise is None:
|
|
||||||
noise = torch.randn(
|
|
||||||
orig_latents.shape,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
|
||||||
|
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
|
||||||
latents = torch.lerp(
|
|
||||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_inpainting_model(self.unet):
|
if is_inpainting_model(self.unet):
|
||||||
if masked_latents is None:
|
if masked_latents is None:
|
||||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||||
@ -348,6 +327,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self._unet_forward, mask, masked_latents
|
self._unet_forward, mask, masked_latents
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# if no noise provided, noisify unmasked area based on seed
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn(
|
||||||
|
orig_latents.shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||||
|
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||||
|
|
||||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -355,6 +343,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents,
|
latents,
|
||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
@ -380,7 +369,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self,
|
self,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
@ -397,22 +387,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if timesteps.shape[0] == 0:
|
if timesteps.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
use_regional_prompting = (
|
||||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
|
||||||
self.invokeai_diffuser.model,
|
|
||||||
extra_conditioning_info=extra_conditioning_info,
|
|
||||||
)
|
)
|
||||||
self.use_ip_adapter = False
|
unet_attention_patcher = None
|
||||||
elif ip_adapter_data is not None:
|
self.use_ip_adapter = use_ip_adapter
|
||||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
|
||||||
# As it is now, the IP-Adapter will silently be skipped.
|
|
||||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
|
||||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
|
||||||
self.use_ip_adapter = True
|
|
||||||
else:
|
|
||||||
attn_ctx = nullcontext()
|
attn_ctx = nullcontext()
|
||||||
|
if use_ip_adapter or use_regional_prompting:
|
||||||
|
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
|
||||||
|
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||||
|
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
|
|
||||||
with attn_ctx:
|
with attn_ctx:
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
@ -435,11 +420,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data,
|
conditioning_data,
|
||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||||
@ -463,14 +448,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self,
|
self,
|
||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
|
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
@ -485,23 +470,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
# handle IP-Adapter
|
|
||||||
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
|
|
||||||
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
|
|
||||||
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
|
|
||||||
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
|
|
||||||
weight = (
|
|
||||||
single_ip_adapter_data.weight[step_index]
|
|
||||||
if isinstance(single_ip_adapter_data.weight, List)
|
|
||||||
else single_ip_adapter_data.weight
|
|
||||||
)
|
|
||||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
|
||||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
|
||||||
ip_adapter_unet_patcher.set_scale(i, weight)
|
|
||||||
else:
|
|
||||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
|
||||||
ip_adapter_unet_patcher.set_scale(i, 0.0)
|
|
||||||
|
|
||||||
# Handle ControlNet(s)
|
# Handle ControlNet(s)
|
||||||
down_block_additional_residuals = None
|
down_block_additional_residuals = None
|
||||||
mid_block_additional_residual = None
|
mid_block_additional_residual = None
|
||||||
@ -550,6 +518,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
|
ip_adapter_data=ip_adapter_data,
|
||||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||||
@ -569,7 +538,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||||
|
|
||||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||||
for guidance in additional_guidance:
|
for guidance in additional_guidance:
|
||||||
|
@ -1,27 +1,17 @@
|
|||||||
import dataclasses
|
import math
|
||||||
import inspect
|
from dataclasses import dataclass
|
||||||
from dataclasses import dataclass, field
|
from typing import List, Optional, Union
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .cross_attention_control import Arguments
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExtraConditioningInfo:
|
|
||||||
tokens_count_including_eos_bos: int
|
|
||||||
cross_attention_control_args: Optional[Arguments] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def wants_cross_attention_control(self):
|
|
||||||
return self.cross_attention_control_args is not None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BasicConditioningInfo:
|
class BasicConditioningInfo:
|
||||||
|
"""SD 1/2 text conditioning information produced by Compel."""
|
||||||
|
|
||||||
embeds: torch.Tensor
|
embeds: torch.Tensor
|
||||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
|
||||||
|
|
||||||
def to(self, device, dtype=None):
|
def to(self, device, dtype=None):
|
||||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||||
@ -35,6 +25,8 @@ class ConditioningFieldData:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
|
"""SDXL text conditioning information produced by Compel."""
|
||||||
|
|
||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
add_time_ids: torch.Tensor
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
@ -57,37 +49,74 @@ class IPAdapterConditioningInfo:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningData:
|
class IPAdapterData:
|
||||||
unconditioned_embeddings: BasicConditioningInfo
|
ip_adapter_model: IPAdapter
|
||||||
text_embeddings: BasicConditioningInfo
|
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||||
"""
|
mask: torch.Tensor
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
|
||||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
|
||||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
|
||||||
"""
|
|
||||||
guidance_scale: Union[float, List[float]]
|
|
||||||
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
|
||||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
|
||||||
"""
|
|
||||||
guidance_rescale_multiplier: float = 0
|
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||||
|
weight: Union[float, List[float]] = 1.0
|
||||||
|
begin_step_percent: float = 0.0
|
||||||
|
end_step_percent: float = 1.0
|
||||||
|
|
||||||
@property
|
def scale_for_step(self, step_index: int, total_steps: int) -> float:
|
||||||
def dtype(self):
|
first_adapter_step = math.floor(self.begin_step_percent * total_steps)
|
||||||
return self.text_embeddings.dtype
|
last_adapter_step = math.ceil(self.end_step_percent * total_steps)
|
||||||
|
weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
|
||||||
|
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||||
|
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||||
|
return weight
|
||||||
|
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
|
||||||
scheduler_args = dict(self.scheduler_args)
|
@dataclass
|
||||||
step_method = inspect.signature(scheduler.step)
|
class Range:
|
||||||
for name, value in kwargs.items():
|
start: int
|
||||||
try:
|
end: int
|
||||||
step_method.bind_partial(**{name: value})
|
|
||||||
except TypeError:
|
|
||||||
# FIXME: don't silently discard arguments
|
class TextConditioningRegions:
|
||||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
def __init__(
|
||||||
else:
|
self,
|
||||||
scheduler_args[name] = value
|
masks: torch.Tensor,
|
||||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
ranges: list[Range],
|
||||||
|
):
|
||||||
|
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||||
|
# Shape: (1, num_prompts, height, width)
|
||||||
|
# Dtype: torch.bool
|
||||||
|
self.masks = masks
|
||||||
|
|
||||||
|
# A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
|
||||||
|
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||||
|
self.ranges = ranges
|
||||||
|
|
||||||
|
assert self.masks.shape[1] == len(self.ranges)
|
||||||
|
|
||||||
|
|
||||||
|
class TextConditioningData:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
|
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
|
uncond_regions: Optional[TextConditioningRegions],
|
||||||
|
cond_regions: Optional[TextConditioningRegions],
|
||||||
|
guidance_scale: Union[float, List[float]],
|
||||||
|
guidance_rescale_multiplier: float = 0,
|
||||||
|
):
|
||||||
|
self.uncond_text = uncond_text
|
||||||
|
self.cond_text = cond_text
|
||||||
|
self.uncond_regions = uncond_regions
|
||||||
|
self.cond_regions = cond_regions
|
||||||
|
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
|
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||||
|
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||||
|
self.guidance_scale = guidance_scale
|
||||||
|
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||||
|
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||||
|
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||||
|
|
||||||
|
def is_sdxl(self):
|
||||||
|
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
@ -1,218 +0,0 @@
|
|||||||
# adapted from bloc97's CrossAttentionControl colab
|
|
||||||
# https://github.com/bloc97/CrossAttentionControl
|
|
||||||
|
|
||||||
|
|
||||||
import enum
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from compel.cross_attention_control import Arguments
|
|
||||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
|
||||||
SELF = 1
|
|
||||||
TOKENS = 2
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttnControlContext:
|
|
||||||
def __init__(self, arguments: Arguments):
|
|
||||||
"""
|
|
||||||
:param arguments: Arguments for the cross-attention control process
|
|
||||||
"""
|
|
||||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
|
||||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
|
||||||
self.arguments = arguments
|
|
||||||
|
|
||||||
def get_active_cross_attention_control_types_for_step(
|
|
||||||
self, percent_through: float = None
|
|
||||||
) -> list[CrossAttentionType]:
|
|
||||||
"""
|
|
||||||
Should cross-attention control be applied on the given step?
|
|
||||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
|
||||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
|
||||||
"""
|
|
||||||
if percent_through is None:
|
|
||||||
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
|
||||||
|
|
||||||
opts = self.arguments.edit_options
|
|
||||||
to_control = []
|
|
||||||
if opts["s_start"] <= percent_through < opts["s_end"]:
|
|
||||||
to_control.append(CrossAttentionType.SELF)
|
|
||||||
if opts["t_start"] <= percent_through < opts["t_end"]:
|
|
||||||
to_control.append(CrossAttentionType.TOKENS)
|
|
||||||
return to_control
|
|
||||||
|
|
||||||
|
|
||||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
|
||||||
"""
|
|
||||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
|
||||||
|
|
||||||
:param model: The unet model to inject into.
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
# adapted from init_attention_edit
|
|
||||||
device = context.arguments.edited_conditioning.device
|
|
||||||
|
|
||||||
# urgh. should this be hardcoded?
|
|
||||||
max_length = 77
|
|
||||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
|
||||||
mask = torch.zeros(max_length, dtype=torch_dtype(device))
|
|
||||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
|
||||||
indices = torch.arange(max_length, dtype=torch.long)
|
|
||||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
|
||||||
if b0 < max_length:
|
|
||||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
|
||||||
# these tokens have not been edited
|
|
||||||
indices[b0:b1] = indices_target[a0:a1]
|
|
||||||
mask[b0:b1] = 1
|
|
||||||
|
|
||||||
context.cross_attention_mask = mask.to(device)
|
|
||||||
context.cross_attention_index_map = indices.to(device)
|
|
||||||
old_attn_processors = unet.attn_processors
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
|
||||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
|
||||||
else:
|
|
||||||
# try to re-use an existing slice size
|
|
||||||
default_slice_size = 4
|
|
||||||
slice_size = next(
|
|
||||||
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
|
|
||||||
)
|
|
||||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SwapCrossAttnContext:
|
|
||||||
modified_text_embeddings: torch.Tensor
|
|
||||||
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
|
|
||||||
mask: torch.Tensor # in the target space of the index_map
|
|
||||||
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
|
||||||
|
|
||||||
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
|
||||||
return attn_type in self.cross_attention_types_to_do
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_mask_and_index_map(
|
|
||||||
cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
|
|
||||||
mask = torch.zeros(max_length)
|
|
||||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
|
||||||
indices = torch.arange(max_length, dtype=torch.long)
|
|
||||||
for name, a0, a1, b0, b1 in edit_opcodes:
|
|
||||||
if b0 < max_length:
|
|
||||||
if name == "equal":
|
|
||||||
# these tokens remain the same as in the original prompt
|
|
||||||
indices[b0:b1] = indices_target[a0:a1]
|
|
||||||
mask[b0:b1] = 1
|
|
||||||
|
|
||||||
return mask, indices
|
|
||||||
|
|
||||||
|
|
||||||
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
|
||||||
# TODO: dynamically pick slice size based on memory conditions
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=None,
|
|
||||||
attention_mask=None,
|
|
||||||
# kwargs
|
|
||||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
|
||||||
|
|
||||||
# if cross-attention control is not in play, just call through to the base implementation.
|
|
||||||
if (
|
|
||||||
attention_type is CrossAttentionType.SELF
|
|
||||||
or swap_cross_attn_context is None
|
|
||||||
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
|
|
||||||
):
|
|
||||||
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
|
||||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
|
||||||
# else:
|
|
||||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = hidden_states.shape
|
|
||||||
attention_mask = attn.prepare_attention_mask(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
target_length=sequence_length,
|
|
||||||
batch_size=batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
dim = query.shape[-1]
|
|
||||||
query = attn.head_to_batch_dim(query)
|
|
||||||
|
|
||||||
original_text_embeddings = encoder_hidden_states
|
|
||||||
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
|
||||||
original_text_key = attn.to_k(original_text_embeddings)
|
|
||||||
modified_text_key = attn.to_k(modified_text_embeddings)
|
|
||||||
original_value = attn.to_v(original_text_embeddings)
|
|
||||||
modified_value = attn.to_v(modified_text_embeddings)
|
|
||||||
|
|
||||||
original_text_key = attn.head_to_batch_dim(original_text_key)
|
|
||||||
modified_text_key = attn.head_to_batch_dim(modified_text_key)
|
|
||||||
original_value = attn.head_to_batch_dim(original_value)
|
|
||||||
modified_value = attn.head_to_batch_dim(modified_value)
|
|
||||||
|
|
||||||
# compute slices and prepare output tensor
|
|
||||||
batch_size_attention = query.shape[0]
|
|
||||||
hidden_states = torch.zeros(
|
|
||||||
(batch_size_attention, sequence_length, dim // attn.heads),
|
|
||||||
device=query.device,
|
|
||||||
dtype=query.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# do slices
|
|
||||||
for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
|
|
||||||
start_idx = i * self.slice_size
|
|
||||||
end_idx = (i + 1) * self.slice_size
|
|
||||||
|
|
||||||
query_slice = query[start_idx:end_idx]
|
|
||||||
original_key_slice = original_text_key[start_idx:end_idx]
|
|
||||||
modified_key_slice = modified_text_key[start_idx:end_idx]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
|
||||||
|
|
||||||
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
|
||||||
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
|
||||||
|
|
||||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
|
||||||
# the original attention probabilities must be remapped to account for token index changes in the
|
|
||||||
# modified prompt
|
|
||||||
remapped_original_attn_slice = torch.index_select(
|
|
||||||
original_attn_slice, -1, swap_cross_attn_context.index_map
|
|
||||||
)
|
|
||||||
|
|
||||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
|
||||||
mask = swap_cross_attn_context.mask
|
|
||||||
inverse_mask = 1 - mask
|
|
||||||
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
|
|
||||||
|
|
||||||
del remapped_original_attn_slice, modified_attn_slice
|
|
||||||
|
|
||||||
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
|
||||||
|
|
||||||
# done
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
|
||||||
def __init__(self):
|
|
||||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
|
198
invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||||
|
|
||||||
|
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||||
|
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||||
|
This implementation is based on
|
||||||
|
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
|
||||||
|
Supported custom features:
|
||||||
|
- IP-Adapter
|
||||||
|
- Regional prompt attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize a CustomAttnProcessor2_0.
|
||||||
|
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||||
|
layer-specific are passed to __init__().
|
||||||
|
Args:
|
||||||
|
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||||
|
for the i'th IP-Adapter.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._ip_adapter_weights = ip_adapter_weights
|
||||||
|
|
||||||
|
def _is_ip_adapter_enabled(self) -> bool:
|
||||||
|
return self._ip_adapter_weights is not None
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
# For regional prompting:
|
||||||
|
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||||
|
percent_through: Optional[torch.FloatTensor] = None,
|
||||||
|
# For IP-Adapter:
|
||||||
|
regional_ip_data: Optional[RegionalIPData] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Apply attention.
|
||||||
|
Args:
|
||||||
|
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
|
||||||
|
apply regional prompt masking.
|
||||||
|
regional_ip_data: The IP-Adapter data for the current batch.
|
||||||
|
"""
|
||||||
|
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
|
residual = hidden_states
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
# End unmodified block from AttnProcessor2_0.
|
||||||
|
|
||||||
|
_, query_seq_len, _ = hidden_states.shape
|
||||||
|
# Handle regional prompt attention masks.
|
||||||
|
if regional_prompt_data is not None and is_cross_attention:
|
||||||
|
assert percent_through is not None
|
||||||
|
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||||
|
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = prompt_region_attention_mask
|
||||||
|
else:
|
||||||
|
attention_mask = prompt_region_attention_mask + attention_mask
|
||||||
|
|
||||||
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
# End unmodified block from AttnProcessor2_0.
|
||||||
|
|
||||||
|
# Apply IP-Adapter conditioning.
|
||||||
|
if is_cross_attention:
|
||||||
|
if self._is_ip_adapter_enabled():
|
||||||
|
assert regional_ip_data is not None
|
||||||
|
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
assert (
|
||||||
|
len(regional_ip_data.image_prompt_embeds)
|
||||||
|
== len(self._ip_adapter_weights)
|
||||||
|
== len(regional_ip_data.scales)
|
||||||
|
== ip_masks.shape[1]
|
||||||
|
)
|
||||||
|
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||||
|
ipa_weights = self._ip_adapter_weights[ipa_index]
|
||||||
|
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||||
|
ip_mask = ip_masks[0, ipa_index, ...]
|
||||||
|
|
||||||
|
# The batch dimensions should match.
|
||||||
|
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||||
|
# The token_len dimensions should match.
|
||||||
|
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||||
|
|
||||||
|
ip_hidden_states = ipa_embed
|
||||||
|
|
||||||
|
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||||
|
|
||||||
|
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||||
|
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||||
|
|
||||||
|
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||||
|
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||||
|
|
||||||
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||||
|
else:
|
||||||
|
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||||
|
assert regional_ip_data is None
|
||||||
|
|
||||||
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
@ -0,0 +1,72 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class RegionalIPData:
|
||||||
|
"""A class to manage the data for regional IP-Adapter conditioning."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_prompt_embeds: list[torch.Tensor],
|
||||||
|
scales: list[float],
|
||||||
|
masks: list[torch.Tensor],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
max_downscale_factor: int = 8,
|
||||||
|
):
|
||||||
|
"""Initialize a `IPAdapterConditioningData` object."""
|
||||||
|
assert len(image_prompt_embeds) == len(scales) == len(masks)
|
||||||
|
|
||||||
|
# The image prompt embeddings.
|
||||||
|
# regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
|
||||||
|
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||||
|
self.image_prompt_embeds = image_prompt_embeds
|
||||||
|
|
||||||
|
# The scales for the IP-Adapter attention.
|
||||||
|
# scales[i] contains the attention scale for the i'th IP-Adapter.
|
||||||
|
self.scales = scales
|
||||||
|
|
||||||
|
# The IP-Adapter masks.
|
||||||
|
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
|
||||||
|
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
|
||||||
|
# regions and 0.0 for excluded regions.
|
||||||
|
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
|
||||||
|
|
||||||
|
def _prepare_masks(
|
||||||
|
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
|
||||||
|
) -> dict[int, torch.Tensor]:
|
||||||
|
"""Prepare the masks for the IP-Adapter attention."""
|
||||||
|
# Concatenate the masks so that they can be processed more efficiently.
|
||||||
|
mask_tensor = torch.cat(masks, dim=1)
|
||||||
|
|
||||||
|
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
masks_by_seq_len: dict[int, torch.Tensor] = {}
|
||||||
|
|
||||||
|
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||||
|
downscale_factor = 1
|
||||||
|
while downscale_factor <= max_downscale_factor:
|
||||||
|
b, num_ip_adapters, h, w = mask_tensor.shape
|
||||||
|
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
|
||||||
|
assert b == 1
|
||||||
|
|
||||||
|
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
|
||||||
|
# the spatial features.
|
||||||
|
query_seq_len = h * w
|
||||||
|
|
||||||
|
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
|
||||||
|
|
||||||
|
downscale_factor *= 2
|
||||||
|
if downscale_factor <= max_downscale_factor:
|
||||||
|
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
|
||||||
|
# regions to be lost entirely.
|
||||||
|
#
|
||||||
|
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
|
||||||
|
#
|
||||||
|
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
|
||||||
|
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
return masks_by_seq_len
|
||||||
|
|
||||||
|
def get_masks(self, query_seq_len: int) -> torch.Tensor:
|
||||||
|
"""Get the mask for the given query sequence length."""
|
||||||
|
return self._masks_by_seq_len[query_seq_len]
|
@ -0,0 +1,105 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
TextConditioningRegions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RegionalPromptData:
|
||||||
|
"""A class to manage the prompt data for regional conditioning."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
regions: list[TextConditioningRegions],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
max_downscale_factor: int = 8,
|
||||||
|
):
|
||||||
|
"""Initialize a `RegionalPromptData` object.
|
||||||
|
Args:
|
||||||
|
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||||
|
batch.
|
||||||
|
device (torch.device): The device to use for the attention masks.
|
||||||
|
dtype (torch.dtype): The data type to use for the attention masks.
|
||||||
|
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||||
|
in steps of 2x.
|
||||||
|
"""
|
||||||
|
self._regions = regions
|
||||||
|
self._device = device
|
||||||
|
self._dtype = dtype
|
||||||
|
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||||
|
# sequence length of s.
|
||||||
|
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||||
|
regions, max_downscale_factor
|
||||||
|
)
|
||||||
|
self._negative_cross_attn_mask_score = -10000.0
|
||||||
|
|
||||||
|
def _prepare_spatial_masks(
|
||||||
|
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||||
|
) -> list[dict[int, torch.Tensor]]:
|
||||||
|
"""Prepare the spatial masks for all downscaling factors."""
|
||||||
|
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||||
|
# of s.
|
||||||
|
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||||
|
|
||||||
|
for batch_sample_regions in regions:
|
||||||
|
batch_sample_masks_by_seq_len.append({})
|
||||||
|
|
||||||
|
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||||
|
|
||||||
|
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||||
|
downscale_factor = 1
|
||||||
|
while downscale_factor <= max_downscale_factor:
|
||||||
|
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||||
|
assert b == 1
|
||||||
|
query_seq_len = h * w
|
||||||
|
|
||||||
|
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
|
||||||
|
|
||||||
|
downscale_factor *= 2
|
||||||
|
if downscale_factor <= max_downscale_factor:
|
||||||
|
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
|
||||||
|
# regions to be lost entirely.
|
||||||
|
#
|
||||||
|
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
|
||||||
|
#
|
||||||
|
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
|
||||||
|
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
|
||||||
|
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)
|
||||||
|
|
||||||
|
return batch_sample_masks_by_seq_len
|
||||||
|
|
||||||
|
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
|
||||||
|
"""Get the cross-attention mask for the given query sequence length.
|
||||||
|
Args:
|
||||||
|
query_seq_len: The length of the flattened spatial features at the current downscaling level.
|
||||||
|
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
|
||||||
|
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The cross-attention score mask.
|
||||||
|
shape: (batch_size, query_seq_len, key_seq_len).
|
||||||
|
dtype: float
|
||||||
|
"""
|
||||||
|
batch_size = len(self._spatial_masks_by_seq_len)
|
||||||
|
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||||
|
|
||||||
|
# Create an empty attention mask with the correct shape.
|
||||||
|
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||||
|
batch_sample_regions = self._regions[batch_idx]
|
||||||
|
|
||||||
|
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||||
|
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
|
||||||
|
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||||
|
|
||||||
|
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||||
|
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
|
||||||
|
batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||||
|
batch_sample_query_scores[batch_sample_query_mask] = 0.0
|
||||||
|
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
|
||||||
|
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||||
|
|
||||||
|
return attn_mask
|
@ -1,26 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
ConditioningData,
|
IPAdapterData,
|
||||||
ExtraConditioningInfo,
|
Range,
|
||||||
SDXLConditioningInfo,
|
TextConditioningData,
|
||||||
)
|
TextConditioningRegions,
|
||||||
|
|
||||||
from .cross_attention_control import (
|
|
||||||
CrossAttentionType,
|
|
||||||
CrossAttnControlContext,
|
|
||||||
SwapCrossAttnContext,
|
|
||||||
setup_cross_attention_control_attention_processors,
|
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||||
|
|
||||||
ModelForwardCallback: TypeAlias = Union[
|
ModelForwardCallback: TypeAlias = Union[
|
||||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||||
@ -58,31 +52,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_forward_callback = model_forward_callback
|
self.model_forward_callback = model_forward_callback
|
||||||
self.cross_attention_control_context = None
|
|
||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def custom_attention_context(
|
|
||||||
self,
|
|
||||||
unet: UNet2DConditionModel,
|
|
||||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
|
||||||
):
|
|
||||||
old_attn_processors = unet.attn_processors
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.cross_attention_control_context = CrossAttnControlContext(
|
|
||||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
|
||||||
)
|
|
||||||
setup_cross_attention_control_attention_processors(
|
|
||||||
unet,
|
|
||||||
self.cross_attention_control_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield None
|
|
||||||
finally:
|
|
||||||
self.cross_attention_control_context = None
|
|
||||||
unet.set_attn_processor(old_attn_processors)
|
|
||||||
|
|
||||||
def do_controlnet_step(
|
def do_controlnet_step(
|
||||||
self,
|
self,
|
||||||
control_data,
|
control_data,
|
||||||
@ -90,7 +61,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
conditioning_data,
|
conditioning_data: TextConditioningData,
|
||||||
):
|
):
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
@ -123,28 +94,28 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.uncond_text.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
conditioning_data.cond_text.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.uncond_text.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
conditioning_data.cond_text.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
@ -153,8 +124,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
) = self._concat_conditionings_for_batch(
|
) = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.uncond_text.embeds,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.cond_text.embeds,
|
||||||
)
|
)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
@ -198,24 +169,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
|
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
):
|
||||||
cross_attention_control_types_to_do = []
|
if self.sequential_guidance:
|
||||||
if self.cross_attention_control_context is not None:
|
|
||||||
percent_through = step_index / total_step_count
|
|
||||||
cross_attention_control_types_to_do = (
|
|
||||||
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
|
||||||
)
|
|
||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
|
||||||
|
|
||||||
if wants_cross_attention_control or self.sequential_guidance:
|
|
||||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
|
||||||
# control is currently only supported in sequential mode.
|
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
@ -223,7 +185,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
x=sample,
|
x=sample,
|
||||||
sigma=timestep,
|
sigma=timestep,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
step_index=step_index,
|
||||||
|
total_step_count=total_step_count,
|
||||||
down_block_additional_residuals=down_block_additional_residuals,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
mid_block_additional_residual=mid_block_additional_residual,
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
@ -236,6 +200,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
x=sample,
|
x=sample,
|
||||||
sigma=timestep,
|
sigma=timestep,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
step_index=step_index,
|
||||||
|
total_step_count=total_step_count,
|
||||||
down_block_additional_residuals=down_block_additional_residuals,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
mid_block_additional_residual=mid_block_additional_residual,
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
@ -294,53 +261,84 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def _apply_standard_conditioning(
|
def _apply_standard_conditioning(
|
||||||
self,
|
self,
|
||||||
x,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma: torch.Tensor,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
|
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||||
|
step_index: int,
|
||||||
|
total_step_count: int,
|
||||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||||
the cost of higher memory usage.
|
the cost of higher memory usage.
|
||||||
"""
|
"""
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = {}
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if ip_adapter_data is not None:
|
||||||
|
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
image_prompt_embeds = [
|
||||||
"ip_adapter_image_prompt_embeds": [
|
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||||
torch.stack(
|
for ipa_conditioning in ip_adapter_conditioning
|
||||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
|
||||||
)
|
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
|
||||||
]
|
]
|
||||||
}
|
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||||
|
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||||
|
regional_ip_data = RegionalIPData(
|
||||||
|
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.uncond_text.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
conditioning_data.cond_text.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.uncond_text.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
conditioning_data.cond_text.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||||
|
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||||
|
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||||
|
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||||
|
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||||
|
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||||
|
regions = []
|
||||||
|
for c, r in [
|
||||||
|
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
||||||
|
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
||||||
|
]:
|
||||||
|
if r is None:
|
||||||
|
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
||||||
|
_, _, h, w = x.shape
|
||||||
|
r = TextConditioningRegions(
|
||||||
|
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
|
||||||
|
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||||
|
)
|
||||||
|
regions.append(r)
|
||||||
|
|
||||||
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
|
regions=regions, device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||||
|
|
||||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||||
)
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
@ -360,8 +358,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
cross_attention_control_types_to_do: list[CrossAttentionType],
|
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||||
|
step_index: int,
|
||||||
|
total_step_count: int,
|
||||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
@ -391,53 +391,48 @@ class InvokeAIDiffuserComponent:
|
|||||||
if mid_block_additional_residual is not None:
|
if mid_block_additional_residual is not None:
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
|
|
||||||
cross_attn_processor_context = None
|
|
||||||
if self.cross_attention_control_context is not None:
|
|
||||||
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
|
|
||||||
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
|
|
||||||
# will be populated before the conditioned pass.
|
|
||||||
cross_attn_processor_context = SwapCrossAttnContext(
|
|
||||||
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
|
|
||||||
index_map=self.cross_attention_control_context.cross_attention_index_map,
|
|
||||||
mask=self.cross_attention_control_context.cross_attention_mask,
|
|
||||||
cross_attention_types_to_do=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# Unconditioned pass
|
# Unconditioned pass
|
||||||
#####################
|
#####################
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = {}
|
||||||
|
|
||||||
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if ip_adapter_data is not None:
|
||||||
|
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
image_prompt_embeds = [
|
||||||
"ip_adapter_image_prompt_embeds": [
|
|
||||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
for ipa_conditioning in ip_adapter_conditioning
|
||||||
]
|
]
|
||||||
}
|
|
||||||
|
|
||||||
# Prepare cross-attention control kwargs for the unconditioned pass.
|
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||||
if cross_attn_processor_context is not None:
|
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||||
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
regional_ip_data = RegionalIPData(
|
||||||
|
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||||
|
|
||||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
if conditioning_data.is_sdxl():
|
||||||
if is_sdxl:
|
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Prepare prompt regions for the unconditioned pass.
|
||||||
|
if conditioning_data.uncond_regions is not None:
|
||||||
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
|
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||||
|
|
||||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.uncond_text.embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
down_block_additional_residuals=uncond_down_block,
|
down_block_additional_residuals=uncond_down_block,
|
||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
@ -449,36 +444,43 @@ class InvokeAIDiffuserComponent:
|
|||||||
# Conditioned pass
|
# Conditioned pass
|
||||||
###################
|
###################
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = {}
|
||||||
|
|
||||||
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
|
if ip_adapter_data is not None:
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
image_prompt_embeds = [
|
||||||
"ip_adapter_image_prompt_embeds": [
|
|
||||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
for ipa_conditioning in ip_adapter_conditioning
|
||||||
]
|
]
|
||||||
}
|
|
||||||
|
|
||||||
# Prepare cross-attention control kwargs for the conditioned pass.
|
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||||
if cross_attn_processor_context is not None:
|
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
regional_ip_data = RegionalIPData(
|
||||||
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||||
|
|
||||||
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if is_sdxl:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Prepare prompt regions for the conditioned pass.
|
||||||
|
if conditioning_data.cond_regions is not None:
|
||||||
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
|
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||||
|
|
||||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.cond_text.embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
down_block_additional_residuals=cond_down_block,
|
down_block_additional_residuals=cond_down_block,
|
||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
@ -1,52 +1,46 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
|
|
||||||
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
|
|
||||||
|
|
||||||
class UNetPatcher:
|
class UNetAttentionPatcher:
|
||||||
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
|
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||||
|
|
||||||
def __init__(self, ip_adapters: list[IPAdapter]):
|
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||||
self._ip_adapters = ip_adapters
|
self._ip_adapters = ip_adapters
|
||||||
self._scales = [1.0] * len(self._ip_adapters)
|
|
||||||
|
|
||||||
def set_scale(self, idx: int, value: float):
|
|
||||||
self._scales[idx] = value
|
|
||||||
|
|
||||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||||
weights into them.
|
weights into them (if IP-Adapters are being applied).
|
||||||
|
|
||||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||||
"""
|
"""
|
||||||
# Construct a dict of attention processors based on the UNet's architecture.
|
# Construct a dict of attention processors based on the UNet's architecture.
|
||||||
attn_procs = {}
|
attn_procs = {}
|
||||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||||
if name.endswith("attn1.processor"):
|
if name.endswith("attn1.processor") or self._ip_adapters is None:
|
||||||
attn_procs[name] = AttnProcessor2_0()
|
# "attn1" processors do not use IP-Adapters.
|
||||||
|
attn_procs[name] = CustomAttnProcessor2_0()
|
||||||
else:
|
else:
|
||||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||||
attn_procs[name] = IPAttnProcessor2_0(
|
attn_procs[name] = CustomAttnProcessor2_0(
|
||||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||||
self._scales,
|
|
||||||
)
|
)
|
||||||
return attn_procs
|
return attn_procs
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
||||||
"""A context manager that patches `unet` with IP-Adapter attention processors."""
|
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
|
||||||
|
|
||||||
attn_procs = self._prepare_attention_processors(unet)
|
attn_procs = self._prepare_attention_processors(unet)
|
||||||
|
|
||||||
orig_attn_processors = unet.attn_processors
|
orig_attn_processors = unet.attn_processors
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
|
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
|
||||||
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
|
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
|
||||||
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||||
unet.set_attn_processor(attn_procs)
|
unet.set_attn_processor(attn_procs)
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
@ -6,8 +6,7 @@ from typing import Literal, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import PRECISION, get_config
|
||||||
from invokeai.app.services.config.config_default import get_config
|
|
||||||
|
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
@ -33,35 +32,34 @@ def get_torch_device_name() -> str:
|
|||||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||||
|
|
||||||
|
|
||||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
|
||||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
|
||||||
def choose_precision(
|
|
||||||
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
|
|
||||||
) -> Literal["float32", "float16", "bfloat16"]:
|
|
||||||
"""Return an appropriate precision for the given torch device."""
|
"""Return an appropriate precision for the given torch device."""
|
||||||
app_config = app_config or get_config()
|
app_config = get_config()
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||||
if app_config.precision == "float32":
|
# These GPUs have limited support for float16
|
||||||
return "float32"
|
return "float32"
|
||||||
elif app_config.precision == "bfloat16":
|
elif app_config.precision == "auto" or app_config.precision == "autocast":
|
||||||
return "bfloat16"
|
# Default to float16 for CUDA devices
|
||||||
|
return "float16"
|
||||||
else:
|
else:
|
||||||
return "float16"
|
# Use the user-defined precision
|
||||||
|
return app_config.precision
|
||||||
elif device.type == "mps":
|
elif device.type == "mps":
|
||||||
|
if app_config.precision == "auto" or app_config.precision == "autocast":
|
||||||
|
# Default to float16 for MPS devices
|
||||||
return "float16"
|
return "float16"
|
||||||
|
else:
|
||||||
|
# Use the user-defined precision
|
||||||
|
return app_config.precision
|
||||||
|
# CPU / safe fallback
|
||||||
return "float32"
|
return "float32"
|
||||||
|
|
||||||
|
|
||||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
||||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
|
||||||
def torch_dtype(
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
app_config: Optional[InvokeAIAppConfig] = None,
|
|
||||||
) -> torch.dtype:
|
|
||||||
device = device or choose_torch_device()
|
device = device or choose_torch_device()
|
||||||
precision = choose_precision(device, app_config)
|
precision = choose_precision(device)
|
||||||
if precision == "float16":
|
if precision == "float16":
|
||||||
return torch.float16
|
return torch.float16
|
||||||
if precision == "bfloat16":
|
if precision == "bfloat16":
|
||||||
@ -71,7 +69,7 @@ def torch_dtype(
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
def choose_autocast(precision):
|
def choose_autocast(precision: PRECISION):
|
||||||
"""Returns an autocast context or nullcontext for the given precision string"""
|
"""Returns an autocast context or nullcontext for the given precision string"""
|
||||||
# float16 currently requires autocast to avoid errors like:
|
# float16 currently requires autocast to avoid errors like:
|
||||||
# 'expected scalar type Half but found Float'
|
# 'expected scalar type Half but found Float'
|
||||||
|
53
invokeai/backend/util/mask.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Standardize the dimensions of a mask tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output mask tensor. The shape is (1, h, w).
|
||||||
|
"""
|
||||||
|
# Get the mask height and width.
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
elif mask.ndim == 3 and mask.shape[0] == 1:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
"""Standardize the format of a mask tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
|
||||||
|
or (h, w).
|
||||||
|
|
||||||
|
out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
|
||||||
|
or 1.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not out_dtype.is_floating_point:
|
||||||
|
raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")
|
||||||
|
|
||||||
|
mask = to_standard_mask_dim(mask)
|
||||||
|
mask = mask.to(out_dtype)
|
||||||
|
|
||||||
|
# Set masked regions to 1.0.
|
||||||
|
if mask.dtype == torch.bool:
|
||||||
|
mask = mask.to(out_dtype)
|
||||||
|
else:
|
||||||
|
mask = mask.to(out_dtype)
|
||||||
|
mask_region = mask > 0.5
|
||||||
|
mask[mask_region] = 1.0
|
||||||
|
mask[~mask_region] = 0.0
|
||||||
|
|
||||||
|
return mask
|
@ -52,6 +52,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@chakra-ui/react-use-size": "^2.1.0",
|
"@chakra-ui/react-use-size": "^2.1.0",
|
||||||
|
"@dagrejs/dagre": "^1.1.1",
|
||||||
"@dagrejs/graphlib": "^2.2.1",
|
"@dagrejs/graphlib": "^2.2.1",
|
||||||
"@dnd-kit/core": "^6.1.0",
|
"@dnd-kit/core": "^6.1.0",
|
||||||
"@dnd-kit/sortable": "^8.0.0",
|
"@dnd-kit/sortable": "^8.0.0",
|
||||||
|
@ -11,6 +11,9 @@ dependencies:
|
|||||||
'@chakra-ui/react-use-size':
|
'@chakra-ui/react-use-size':
|
||||||
specifier: ^2.1.0
|
specifier: ^2.1.0
|
||||||
version: 2.1.0(react@18.2.0)
|
version: 2.1.0(react@18.2.0)
|
||||||
|
'@dagrejs/dagre':
|
||||||
|
specifier: ^1.1.1
|
||||||
|
version: 1.1.1
|
||||||
'@dagrejs/graphlib':
|
'@dagrejs/graphlib':
|
||||||
specifier: ^2.2.1
|
specifier: ^2.2.1
|
||||||
version: 2.2.1
|
version: 2.2.1
|
||||||
@ -3092,6 +3095,12 @@ packages:
|
|||||||
dev: true
|
dev: true
|
||||||
optional: true
|
optional: true
|
||||||
|
|
||||||
|
/@dagrejs/dagre@1.1.1:
|
||||||
|
resolution: {integrity: sha512-AQfT6pffEuPE32weFzhS/u3UpX+bRXUARIXL7UqLaxz497cN8pjuBlX6axO4IIECE2gBV8eLFQkGCtKX5sDaUA==}
|
||||||
|
dependencies:
|
||||||
|
'@dagrejs/graphlib': 2.2.1
|
||||||
|
dev: false
|
||||||
|
|
||||||
/@dagrejs/graphlib@2.2.1:
|
/@dagrejs/graphlib@2.2.1:
|
||||||
resolution: {integrity: sha512-xJsN1v6OAxXk6jmNdM+OS/bBE8nDCwM0yDNprXR18ZNatL6to9ggod9+l2XtiLhXfLm0NkE7+Er/cpdlM+SkUA==}
|
resolution: {integrity: sha512-xJsN1v6OAxXk6jmNdM+OS/bBE8nDCwM0yDNprXR18ZNatL6to9ggod9+l2XtiLhXfLm0NkE7+Er/cpdlM+SkUA==}
|
||||||
engines: {node: '>17.0.0'}
|
engines: {node: '>17.0.0'}
|
||||||
|
@ -291,7 +291,6 @@
|
|||||||
"canvasMerged": "تم دمج الخط",
|
"canvasMerged": "تم دمج الخط",
|
||||||
"sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
|
"sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
|
||||||
"sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
|
"sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
|
||||||
"parametersSet": "تم تعيين المعلمات",
|
|
||||||
"parametersNotSet": "لم يتم تعيين المعلمات",
|
"parametersNotSet": "لم يتم تعيين المعلمات",
|
||||||
"metadataLoadFailed": "فشل تحميل البيانات الوصفية"
|
"metadataLoadFailed": "فشل تحميل البيانات الوصفية"
|
||||||
},
|
},
|
||||||
|
@ -75,7 +75,8 @@
|
|||||||
"copy": "Kopieren",
|
"copy": "Kopieren",
|
||||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||||
"toResolve": "Lösen",
|
"toResolve": "Lösen",
|
||||||
"add": "Hinzufügen"
|
"add": "Hinzufügen",
|
||||||
|
"loglevel": "Protokoll Stufe"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"galleryImageSize": "Bildgröße",
|
"galleryImageSize": "Bildgröße",
|
||||||
@ -388,7 +389,14 @@
|
|||||||
"vaePrecision": "VAE-Präzision",
|
"vaePrecision": "VAE-Präzision",
|
||||||
"variant": "Variante",
|
"variant": "Variante",
|
||||||
"modelDeleteFailed": "Modell konnte nicht gelöscht werden",
|
"modelDeleteFailed": "Modell konnte nicht gelöscht werden",
|
||||||
"noModelSelected": "Kein Modell ausgewählt"
|
"noModelSelected": "Kein Modell ausgewählt",
|
||||||
|
"huggingFace": "HuggingFace",
|
||||||
|
"defaultSettings": "Standardeinstellungen",
|
||||||
|
"edit": "Bearbeiten",
|
||||||
|
"cancel": "Stornieren",
|
||||||
|
"defaultSettingsSaved": "Standardeinstellungen gespeichert",
|
||||||
|
"addModels": "Model hinzufügen",
|
||||||
|
"deleteModelImage": "Lösche Model Bild"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Bilder",
|
"images": "Bilder",
|
||||||
@ -472,7 +480,6 @@
|
|||||||
"canvasMerged": "Leinwand zusammengeführt",
|
"canvasMerged": "Leinwand zusammengeführt",
|
||||||
"sentToImageToImage": "Gesendet an Bild zu Bild",
|
"sentToImageToImage": "Gesendet an Bild zu Bild",
|
||||||
"sentToUnifiedCanvas": "Gesendet an Leinwand",
|
"sentToUnifiedCanvas": "Gesendet an Leinwand",
|
||||||
"parametersSet": "Parameter festlegen",
|
|
||||||
"parametersNotSet": "Parameter nicht festgelegt",
|
"parametersNotSet": "Parameter nicht festgelegt",
|
||||||
"metadataLoadFailed": "Metadaten konnten nicht geladen werden",
|
"metadataLoadFailed": "Metadaten konnten nicht geladen werden",
|
||||||
"setCanvasInitialImage": "Ausgangsbild setzen",
|
"setCanvasInitialImage": "Ausgangsbild setzen",
|
||||||
@ -677,7 +684,8 @@
|
|||||||
"body": "Körper",
|
"body": "Körper",
|
||||||
"hands": "Hände",
|
"hands": "Hände",
|
||||||
"dwOpenpose": "DW Openpose",
|
"dwOpenpose": "DW Openpose",
|
||||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose"
|
"dwOpenposeDescription": "Posenschätzung mit DW Openpose",
|
||||||
|
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
|
||||||
},
|
},
|
||||||
"queue": {
|
"queue": {
|
||||||
"status": "Status",
|
"status": "Status",
|
||||||
@ -765,7 +773,10 @@
|
|||||||
"recallParameters": "Parameter wiederherstellen",
|
"recallParameters": "Parameter wiederherstellen",
|
||||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||||
"allPrompts": "Alle Prompts",
|
"allPrompts": "Alle Prompts",
|
||||||
"imageDimensions": "Bilder Auslösungen"
|
"imageDimensions": "Bilder Auslösungen",
|
||||||
|
"parameterSet": "Parameter {{parameter}} setzen",
|
||||||
|
"recallParameter": "{{label}} Abrufen",
|
||||||
|
"parsingFailed": "Parsing Fehlgeschlagen"
|
||||||
},
|
},
|
||||||
"popovers": {
|
"popovers": {
|
||||||
"noiseUseCPU": {
|
"noiseUseCPU": {
|
||||||
@ -1030,7 +1041,8 @@
|
|||||||
"title": "Bild"
|
"title": "Bild"
|
||||||
},
|
},
|
||||||
"advanced": {
|
"advanced": {
|
||||||
"title": "Erweitert"
|
"title": "Erweitert",
|
||||||
|
"options": "$t(accordions.advanced.title) Optionen"
|
||||||
},
|
},
|
||||||
"control": {
|
"control": {
|
||||||
"title": "Kontrolle"
|
"title": "Kontrolle"
|
||||||
|
@ -684,6 +684,7 @@
|
|||||||
"noModelsInstalled": "No Models Installed",
|
"noModelsInstalled": "No Models Installed",
|
||||||
"noModelsInstalledDesc1": "Install models with the",
|
"noModelsInstalledDesc1": "Install models with the",
|
||||||
"noModelSelected": "No Model Selected",
|
"noModelSelected": "No Model Selected",
|
||||||
|
"noMatchingModels": "No matching Models",
|
||||||
"none": "none",
|
"none": "none",
|
||||||
"path": "Path",
|
"path": "Path",
|
||||||
"pathToConfig": "Path To Config",
|
"pathToConfig": "Path To Config",
|
||||||
@ -848,6 +849,7 @@
|
|||||||
"version": "Version",
|
"version": "Version",
|
||||||
"versionUnknown": " Version Unknown",
|
"versionUnknown": " Version Unknown",
|
||||||
"workflow": "Workflow",
|
"workflow": "Workflow",
|
||||||
|
"graph": "Graph",
|
||||||
"workflowAuthor": "Author",
|
"workflowAuthor": "Author",
|
||||||
"workflowContact": "Contact",
|
"workflowContact": "Contact",
|
||||||
"workflowDescription": "Short Description",
|
"workflowDescription": "Short Description",
|
||||||
@ -887,6 +889,11 @@
|
|||||||
"imageFit": "Fit Initial Image To Output Size",
|
"imageFit": "Fit Initial Image To Output Size",
|
||||||
"images": "Images",
|
"images": "Images",
|
||||||
"infillMethod": "Infill Method",
|
"infillMethod": "Infill Method",
|
||||||
|
"infillMosaicTileWidth": "Tile Width",
|
||||||
|
"infillMosaicTileHeight": "Tile Height",
|
||||||
|
"infillMosaicMinColor": "Min Color",
|
||||||
|
"infillMosaicMaxColor": "Max Color",
|
||||||
|
"infillColorValue": "Fill Color",
|
||||||
"info": "Info",
|
"info": "Info",
|
||||||
"invoke": {
|
"invoke": {
|
||||||
"addingImagesTo": "Adding images to",
|
"addingImagesTo": "Adding images to",
|
||||||
@ -1035,10 +1042,10 @@
|
|||||||
"metadataLoadFailed": "Failed to load metadata",
|
"metadataLoadFailed": "Failed to load metadata",
|
||||||
"modelAddedSimple": "Model Added to Queue",
|
"modelAddedSimple": "Model Added to Queue",
|
||||||
"modelImportCanceled": "Model Import Canceled",
|
"modelImportCanceled": "Model Import Canceled",
|
||||||
|
"parameters": "Parameters",
|
||||||
"parameterNotSet": "{{parameter}} not set",
|
"parameterNotSet": "{{parameter}} not set",
|
||||||
"parameterSet": "{{parameter}} set",
|
"parameterSet": "{{parameter}} set",
|
||||||
"parametersNotSet": "Parameters Not Set",
|
"parametersNotSet": "Parameters Not Set",
|
||||||
"parametersSet": "Parameters Set",
|
|
||||||
"problemCopyingCanvas": "Problem Copying Canvas",
|
"problemCopyingCanvas": "Problem Copying Canvas",
|
||||||
"problemCopyingCanvasDesc": "Unable to export base layer",
|
"problemCopyingCanvasDesc": "Unable to export base layer",
|
||||||
"problemCopyingImage": "Unable to Copy Image",
|
"problemCopyingImage": "Unable to Copy Image",
|
||||||
@ -1417,6 +1424,7 @@
|
|||||||
"eraseBoundingBox": "Erase Bounding Box",
|
"eraseBoundingBox": "Erase Bounding Box",
|
||||||
"eraser": "Eraser",
|
"eraser": "Eraser",
|
||||||
"fillBoundingBox": "Fill Bounding Box",
|
"fillBoundingBox": "Fill Bounding Box",
|
||||||
|
"initialFitImageSize": "Fit Image Size on Drop",
|
||||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||||
"layer": "Layer",
|
"layer": "Layer",
|
||||||
"limitStrokesToBox": "Limit Strokes to Box",
|
"limitStrokesToBox": "Limit Strokes to Box",
|
||||||
@ -1475,7 +1483,11 @@
|
|||||||
"workflowName": "Workflow Name",
|
"workflowName": "Workflow Name",
|
||||||
"newWorkflowCreated": "New Workflow Created",
|
"newWorkflowCreated": "New Workflow Created",
|
||||||
"workflowCleared": "Workflow Cleared",
|
"workflowCleared": "Workflow Cleared",
|
||||||
"workflowEditorMenu": "Workflow Editor Menu"
|
"workflowEditorMenu": "Workflow Editor Menu",
|
||||||
|
"loadFromGraph": "Load Workflow from Graph",
|
||||||
|
"convertGraph": "Convert Graph",
|
||||||
|
"loadWorkflow": "$t(common.load) Workflow",
|
||||||
|
"autoLayout": "Auto Layout"
|
||||||
},
|
},
|
||||||
"app": {
|
"app": {
|
||||||
"storeNotInitialized": "Store is not initialized"
|
"storeNotInitialized": "Store is not initialized"
|
||||||
|
@ -363,7 +363,6 @@
|
|||||||
"canvasMerged": "Lienzo consolidado",
|
"canvasMerged": "Lienzo consolidado",
|
||||||
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
|
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
|
||||||
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
|
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
|
||||||
"parametersSet": "Parámetros establecidos",
|
|
||||||
"parametersNotSet": "Parámetros no establecidos",
|
"parametersNotSet": "Parámetros no establecidos",
|
||||||
"metadataLoadFailed": "Error al cargar metadatos",
|
"metadataLoadFailed": "Error al cargar metadatos",
|
||||||
"serverError": "Error en el servidor",
|
"serverError": "Error en el servidor",
|
||||||
|
@ -298,7 +298,6 @@
|
|||||||
"canvasMerged": "Canvas fusionné",
|
"canvasMerged": "Canvas fusionné",
|
||||||
"sentToImageToImage": "Envoyé à Image à Image",
|
"sentToImageToImage": "Envoyé à Image à Image",
|
||||||
"sentToUnifiedCanvas": "Envoyé à Canvas unifié",
|
"sentToUnifiedCanvas": "Envoyé à Canvas unifié",
|
||||||
"parametersSet": "Paramètres définis",
|
|
||||||
"parametersNotSet": "Paramètres non définis",
|
"parametersNotSet": "Paramètres non définis",
|
||||||
"metadataLoadFailed": "Échec du chargement des métadonnées"
|
"metadataLoadFailed": "Échec du chargement des métadonnées"
|
||||||
},
|
},
|
||||||
|
@ -306,7 +306,6 @@
|
|||||||
"canvasMerged": "קנבס מוזג",
|
"canvasMerged": "קנבס מוזג",
|
||||||
"sentToImageToImage": "נשלח לתמונה לתמונה",
|
"sentToImageToImage": "נשלח לתמונה לתמונה",
|
||||||
"sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
|
"sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
|
||||||
"parametersSet": "הגדרת פרמטרים",
|
|
||||||
"parametersNotSet": "פרמטרים לא הוגדרו",
|
"parametersNotSet": "פרמטרים לא הוגדרו",
|
||||||
"metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
|
"metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
|
||||||
},
|
},
|
||||||
|
@ -366,7 +366,7 @@
|
|||||||
"modelConverted": "Modello convertito",
|
"modelConverted": "Modello convertito",
|
||||||
"alpha": "Alpha",
|
"alpha": "Alpha",
|
||||||
"convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
|
"convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
|
||||||
"convertToDiffusersHelpText3": "Il file Checkpoint su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
"convertToDiffusersHelpText3": "Il file del modello su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
||||||
"v2_base": "v2 (512px)",
|
"v2_base": "v2 (512px)",
|
||||||
"v2_768": "v2 (768px)",
|
"v2_768": "v2 (768px)",
|
||||||
"none": "nessuno",
|
"none": "nessuno",
|
||||||
@ -443,7 +443,8 @@
|
|||||||
"noModelsInstalled": "Nessun modello installato",
|
"noModelsInstalled": "Nessun modello installato",
|
||||||
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
||||||
"main": "Principali",
|
"main": "Principali",
|
||||||
"noModelsInstalledDesc1": "Installa i modelli con"
|
"noModelsInstalledDesc1": "Installa i modelli con",
|
||||||
|
"ipAdapters": "Adattatori IP"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Immagini",
|
"images": "Immagini",
|
||||||
@ -568,7 +569,6 @@
|
|||||||
"canvasMerged": "Tela unita",
|
"canvasMerged": "Tela unita",
|
||||||
"sentToImageToImage": "Inviato a Immagine a Immagine",
|
"sentToImageToImage": "Inviato a Immagine a Immagine",
|
||||||
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
|
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
|
||||||
"parametersSet": "Parametri impostati",
|
|
||||||
"parametersNotSet": "Parametri non impostati",
|
"parametersNotSet": "Parametri non impostati",
|
||||||
"metadataLoadFailed": "Impossibile caricare i metadati",
|
"metadataLoadFailed": "Impossibile caricare i metadati",
|
||||||
"serverError": "Errore del Server",
|
"serverError": "Errore del Server",
|
||||||
@ -937,7 +937,8 @@
|
|||||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||||
"mediapipeFace": "Mediapipe Volto",
|
"mediapipeFace": "Mediapipe Volto",
|
||||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
|
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||||
|
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
|
||||||
},
|
},
|
||||||
"queue": {
|
"queue": {
|
||||||
"queueFront": "Aggiungi all'inizio della coda",
|
"queueFront": "Aggiungi all'inizio della coda",
|
||||||
|
@ -420,7 +420,6 @@
|
|||||||
"canvasMerged": "Canvas samengevoegd",
|
"canvasMerged": "Canvas samengevoegd",
|
||||||
"sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
|
"sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
|
||||||
"sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
|
"sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
|
||||||
"parametersSet": "Parameters ingesteld",
|
|
||||||
"parametersNotSet": "Parameters niet ingesteld",
|
"parametersNotSet": "Parameters niet ingesteld",
|
||||||
"metadataLoadFailed": "Fout bij laden metagegevens",
|
"metadataLoadFailed": "Fout bij laden metagegevens",
|
||||||
"serverError": "Serverfout",
|
"serverError": "Serverfout",
|
||||||
|
@ -267,7 +267,6 @@
|
|||||||
"canvasMerged": "Scalono widoczne warstwy",
|
"canvasMerged": "Scalono widoczne warstwy",
|
||||||
"sentToImageToImage": "Wysłano do Obraz na obraz",
|
"sentToImageToImage": "Wysłano do Obraz na obraz",
|
||||||
"sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
|
"sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
|
||||||
"parametersSet": "Ustawiono parametry",
|
|
||||||
"parametersNotSet": "Nie ustawiono parametrów",
|
"parametersNotSet": "Nie ustawiono parametrów",
|
||||||
"metadataLoadFailed": "Błąd wczytywania metadanych"
|
"metadataLoadFailed": "Błąd wczytywania metadanych"
|
||||||
},
|
},
|
||||||
|
@ -310,7 +310,6 @@
|
|||||||
"canvasMerged": "Tela Fundida",
|
"canvasMerged": "Tela Fundida",
|
||||||
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
||||||
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
||||||
"parametersSet": "Parâmetros Definidos",
|
|
||||||
"parametersNotSet": "Parâmetros Não Definidos",
|
"parametersNotSet": "Parâmetros Não Definidos",
|
||||||
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
||||||
},
|
},
|
||||||
|
@ -307,7 +307,6 @@
|
|||||||
"canvasMerged": "Tela Fundida",
|
"canvasMerged": "Tela Fundida",
|
||||||
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
||||||
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
||||||
"parametersSet": "Parâmetros Definidos",
|
|
||||||
"parametersNotSet": "Parâmetros Não Definidos",
|
"parametersNotSet": "Parâmetros Não Definidos",
|
||||||
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
||||||
},
|
},
|
||||||
|
@ -575,7 +575,6 @@
|
|||||||
"canvasMerged": "Холст объединен",
|
"canvasMerged": "Холст объединен",
|
||||||
"sentToImageToImage": "Отправить в img2img",
|
"sentToImageToImage": "Отправить в img2img",
|
||||||
"sentToUnifiedCanvas": "Отправлено на Единый холст",
|
"sentToUnifiedCanvas": "Отправлено на Единый холст",
|
||||||
"parametersSet": "Параметры заданы",
|
|
||||||
"parametersNotSet": "Параметры не заданы",
|
"parametersNotSet": "Параметры не заданы",
|
||||||
"metadataLoadFailed": "Не удалось загрузить метаданные",
|
"metadataLoadFailed": "Не удалось загрузить метаданные",
|
||||||
"serverError": "Ошибка сервера",
|
"serverError": "Ошибка сервера",
|
||||||
|
@ -315,7 +315,6 @@
|
|||||||
"canvasMerged": "Полотно об'єднане",
|
"canvasMerged": "Полотно об'єднане",
|
||||||
"sentToImageToImage": "Надіслати до img2img",
|
"sentToImageToImage": "Надіслати до img2img",
|
||||||
"sentToUnifiedCanvas": "Надіслати на полотно",
|
"sentToUnifiedCanvas": "Надіслати на полотно",
|
||||||
"parametersSet": "Параметри задані",
|
|
||||||
"parametersNotSet": "Параметри не задані",
|
"parametersNotSet": "Параметри не задані",
|
||||||
"metadataLoadFailed": "Не вдалося завантажити метадані",
|
"metadataLoadFailed": "Не вдалося завантажити метадані",
|
||||||
"serverError": "Помилка сервера",
|
"serverError": "Помилка сервера",
|
||||||
|
@ -487,7 +487,6 @@
|
|||||||
"canvasMerged": "画布已合并",
|
"canvasMerged": "画布已合并",
|
||||||
"sentToImageToImage": "已发送到图生图",
|
"sentToImageToImage": "已发送到图生图",
|
||||||
"sentToUnifiedCanvas": "已发送到统一画布",
|
"sentToUnifiedCanvas": "已发送到统一画布",
|
||||||
"parametersSet": "参数已设定",
|
|
||||||
"parametersNotSet": "参数未设定",
|
"parametersNotSet": "参数未设定",
|
||||||
"metadataLoadFailed": "加载元数据失败",
|
"metadataLoadFailed": "加载元数据失败",
|
||||||
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
|
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
import { isAnyOf } from '@reduxjs/toolkit';
|
import { isAnyOf } from '@reduxjs/toolkit';
|
||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { canvasBatchIdsReset, commitStagingAreaImage, discardStagedImages } from 'features/canvas/store/canvasSlice';
|
import {
|
||||||
|
canvasBatchIdsReset,
|
||||||
|
commitStagingAreaImage,
|
||||||
|
discardStagedImages,
|
||||||
|
resetCanvas,
|
||||||
|
setInitialCanvasImage,
|
||||||
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { queueApi } from 'services/api/endpoints/queue';
|
import { queueApi } from 'services/api/endpoints/queue';
|
||||||
|
|
||||||
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages);
|
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
|
||||||
|
|
||||||
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
|
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
|
@ -49,14 +49,20 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
|||||||
const ClearStagingIntermediatesIconButton = () => {
|
const ClearStagingIntermediatesIconButton = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const totalStagedImages = useAppSelector((s) => s.canvas.layerState.stagingArea.images.length);
|
||||||
|
|
||||||
const handleDiscardStagingArea = useCallback(() => {
|
const handleDiscardStagingArea = useCallback(() => {
|
||||||
dispatch(discardStagedImages());
|
dispatch(discardStagedImages());
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleDiscardStagingImage = useCallback(() => {
|
const handleDiscardStagingImage = useCallback(() => {
|
||||||
|
// Discarding all staged images triggers cancelation of all canvas batches. It's too easy to accidentally
|
||||||
|
// click the discard button, so to prevent accidental cancelation of all batches, we only discard the current
|
||||||
|
// image if there are more than one staged images.
|
||||||
|
if (totalStagedImages > 1) {
|
||||||
dispatch(discardStagedImage());
|
dispatch(discardStagedImage());
|
||||||
}, [dispatch]);
|
}
|
||||||
|
}, [dispatch, totalStagedImages]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -67,6 +73,7 @@ const ClearStagingIntermediatesIconButton = () => {
|
|||||||
onClick={handleDiscardStagingImage}
|
onClick={handleDiscardStagingImage}
|
||||||
colorScheme="invokeBlue"
|
colorScheme="invokeBlue"
|
||||||
fontSize={16}
|
fontSize={16}
|
||||||
|
isDisabled={totalStagedImages <= 1}
|
||||||
/>
|
/>
|
||||||
<IconButton
|
<IconButton
|
||||||
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
||||||
|
@ -18,6 +18,7 @@ import {
|
|||||||
setShouldAutoSave,
|
setShouldAutoSave,
|
||||||
setShouldCropToBoundingBoxOnSave,
|
setShouldCropToBoundingBoxOnSave,
|
||||||
setShouldDarkenOutsideBoundingBox,
|
setShouldDarkenOutsideBoundingBox,
|
||||||
|
setShouldFitImageSize,
|
||||||
setShouldInvertBrushSizeScrollDirection,
|
setShouldInvertBrushSizeScrollDirection,
|
||||||
setShouldRestrictStrokesToBox,
|
setShouldRestrictStrokesToBox,
|
||||||
setShouldShowCanvasDebugInfo,
|
setShouldShowCanvasDebugInfo,
|
||||||
@ -48,6 +49,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
|
const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
|
||||||
const shouldRestrictStrokesToBox = useAppSelector((s) => s.canvas.shouldRestrictStrokesToBox);
|
const shouldRestrictStrokesToBox = useAppSelector((s) => s.canvas.shouldRestrictStrokesToBox);
|
||||||
const shouldAntialias = useAppSelector((s) => s.canvas.shouldAntialias);
|
const shouldAntialias = useAppSelector((s) => s.canvas.shouldAntialias);
|
||||||
|
const shouldFitImageSize = useAppSelector((s) => s.canvas.shouldFitImageSize);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
['n'],
|
['n'],
|
||||||
@ -102,6 +104,10 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAntialias(e.target.checked)),
|
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAntialias(e.target.checked)),
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
const handleChangeShouldFitImageSize = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldFitImageSize(e.target.checked)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover>
|
<Popover>
|
||||||
@ -165,6 +171,10 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
<FormLabel>{t('unifiedCanvas.antialiasing')}</FormLabel>
|
<FormLabel>{t('unifiedCanvas.antialiasing')}</FormLabel>
|
||||||
<Checkbox isChecked={shouldAntialias} onChange={handleChangeShouldAntialias} />
|
<Checkbox isChecked={shouldAntialias} onChange={handleChangeShouldAntialias} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>{t('unifiedCanvas.initialFitImageSize')}</FormLabel>
|
||||||
|
<Checkbox isChecked={shouldFitImageSize} onChange={handleChangeShouldFitImageSize} />
|
||||||
|
</FormControl>
|
||||||
</FormControlGroup>
|
</FormControlGroup>
|
||||||
<ClearCanvasHistoryButtonModal />
|
<ClearCanvasHistoryButtonModal />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
|
|||||||
shouldAutoSave: false,
|
shouldAutoSave: false,
|
||||||
shouldCropToBoundingBoxOnSave: false,
|
shouldCropToBoundingBoxOnSave: false,
|
||||||
shouldDarkenOutsideBoundingBox: false,
|
shouldDarkenOutsideBoundingBox: false,
|
||||||
|
shouldFitImageSize: true,
|
||||||
shouldInvertBrushSizeScrollDirection: false,
|
shouldInvertBrushSizeScrollDirection: false,
|
||||||
shouldLockBoundingBox: false,
|
shouldLockBoundingBox: false,
|
||||||
shouldPreserveMaskedArea: false,
|
shouldPreserveMaskedArea: false,
|
||||||
@ -144,11 +145,19 @@ export const canvasSlice = createSlice({
|
|||||||
reducer: (state, action: PayloadActionWithOptimalDimension<ImageDTO>) => {
|
reducer: (state, action: PayloadActionWithOptimalDimension<ImageDTO>) => {
|
||||||
const { width, height, image_name } = action.payload;
|
const { width, height, image_name } = action.payload;
|
||||||
const { optimalDimension } = action.meta;
|
const { optimalDimension } = action.meta;
|
||||||
const { stageDimensions } = state;
|
const { stageDimensions, shouldFitImageSize } = state;
|
||||||
|
|
||||||
const newBoundingBoxDimensions = {
|
const newBoundingBoxDimensions = shouldFitImageSize
|
||||||
|
? {
|
||||||
|
width: roundDownToMultiple(width, CANVAS_GRID_SIZE_FINE),
|
||||||
|
height: roundDownToMultiple(height, CANVAS_GRID_SIZE_FINE),
|
||||||
|
}
|
||||||
|
: {
|
||||||
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||||
height: roundDownToMultiple(clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
height: roundDownToMultiple(
|
||||||
|
clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension),
|
||||||
|
CANVAS_GRID_SIZE_FINE
|
||||||
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
const newBoundingBoxCoordinates = {
|
const newBoundingBoxCoordinates = {
|
||||||
@ -181,7 +190,6 @@ export const canvasSlice = createSlice({
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.batchIds = [];
|
|
||||||
|
|
||||||
const newScale = calculateScale(
|
const newScale = calculateScale(
|
||||||
stageDimensions.width,
|
stageDimensions.width,
|
||||||
@ -277,33 +285,14 @@ export const canvasSlice = createSlice({
|
|||||||
},
|
},
|
||||||
discardStagedImages: (state) => {
|
discardStagedImages: (state) => {
|
||||||
pushToPrevLayerStates(state);
|
pushToPrevLayerStates(state);
|
||||||
|
resetStagingArea(state);
|
||||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.shouldShowStagingOutline = true;
|
|
||||||
state.shouldShowStagingImage = true;
|
|
||||||
state.batchIds = [];
|
|
||||||
},
|
},
|
||||||
discardStagedImage: (state) => {
|
discardStagedImage: (state) => {
|
||||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||||
pushToPrevLayerStates(state);
|
pushToPrevLayerStates(state);
|
||||||
|
|
||||||
if (!images.length) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
images.splice(selectedImageIndex, 1);
|
images.splice(selectedImageIndex, 1);
|
||||||
|
state.layerState.stagingArea.selectedImageIndex = Math.max(0, images.length - 1);
|
||||||
if (selectedImageIndex >= images.length) {
|
|
||||||
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!images.length) {
|
|
||||||
state.shouldShowStagingImage = false;
|
|
||||||
state.shouldShowStagingOutline = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
},
|
},
|
||||||
addFillRect: (state) => {
|
addFillRect: (state) => {
|
||||||
@ -417,7 +406,6 @@ export const canvasSlice = createSlice({
|
|||||||
pushToPrevLayerStates(state);
|
pushToPrevLayerStates(state);
|
||||||
state.layerState = deepClone(initialLayerState);
|
state.layerState = deepClone(initialLayerState);
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.batchIds = [];
|
|
||||||
state.boundingBoxCoordinates = {
|
state.boundingBoxCoordinates = {
|
||||||
...initialCanvasState.boundingBoxCoordinates,
|
...initialCanvasState.boundingBoxCoordinates,
|
||||||
};
|
};
|
||||||
@ -518,12 +506,9 @@ export const canvasSlice = createSlice({
|
|||||||
...imageToCommit,
|
...imageToCommit,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
|
||||||
|
|
||||||
|
resetStagingArea(state);
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.shouldShowStagingOutline = true;
|
|
||||||
state.shouldShowStagingImage = true;
|
|
||||||
state.batchIds = [];
|
|
||||||
},
|
},
|
||||||
setBoundingBoxScaleMethod: {
|
setBoundingBoxScaleMethod: {
|
||||||
reducer: (state, action: PayloadActionWithOptimalDimension<BoundingBoxScaleMethod>) => {
|
reducer: (state, action: PayloadActionWithOptimalDimension<BoundingBoxScaleMethod>) => {
|
||||||
@ -575,6 +560,9 @@ export const canvasSlice = createSlice({
|
|||||||
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldAntialias = action.payload;
|
state.shouldAntialias = action.payload;
|
||||||
},
|
},
|
||||||
|
setShouldFitImageSize: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldFitImageSize = action.payload;
|
||||||
|
},
|
||||||
setShouldCropToBoundingBoxOnSave: (state, action: PayloadAction<boolean>) => {
|
setShouldCropToBoundingBoxOnSave: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldCropToBoundingBoxOnSave = action.payload;
|
state.shouldCropToBoundingBoxOnSave = action.payload;
|
||||||
},
|
},
|
||||||
@ -628,12 +616,19 @@ export const canvasSlice = createSlice({
|
|||||||
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
|
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
|
||||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const queueItemStatus = action.payload.data.queue_item.status;
|
||||||
|
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||||
|
resetStagingAreaIfEmpty(state);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
|
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
|
||||||
state.batchIds = [];
|
state.batchIds = [];
|
||||||
|
resetStagingAreaIfEmpty(state);
|
||||||
});
|
});
|
||||||
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
|
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
|
||||||
state.batchIds = state.batchIds.filter((id) => !action.meta.arg.originalArgs.batch_ids.includes(id));
|
state.batchIds = state.batchIds.filter((id) => !action.meta.arg.originalArgs.batch_ids.includes(id));
|
||||||
|
resetStagingAreaIfEmpty(state);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -685,6 +680,7 @@ export const {
|
|||||||
setShouldRestrictStrokesToBox,
|
setShouldRestrictStrokesToBox,
|
||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
setShouldAntialias,
|
setShouldAntialias,
|
||||||
|
setShouldFitImageSize,
|
||||||
canvasResized,
|
canvasResized,
|
||||||
canvasBatchIdAdded,
|
canvasBatchIdAdded,
|
||||||
canvasBatchIdsReset,
|
canvasBatchIdsReset,
|
||||||
@ -706,7 +702,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
|||||||
name: canvasSlice.name,
|
name: canvasSlice.name,
|
||||||
initialState: initialCanvasState,
|
initialState: initialCanvasState,
|
||||||
migrate: migrateCanvasState,
|
migrate: migrateCanvasState,
|
||||||
persistDenylist: [],
|
persistDenylist: ['shouldShowStagingImage', 'shouldShowStagingOutline'],
|
||||||
};
|
};
|
||||||
|
|
||||||
const pushToPrevLayerStates = (state: CanvasState) => {
|
const pushToPrevLayerStates = (state: CanvasState) => {
|
||||||
@ -722,3 +718,15 @@ const pushToFutureLayerStates = (state: CanvasState) => {
|
|||||||
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
|
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const resetStagingAreaIfEmpty = (state: CanvasState) => {
|
||||||
|
if (state.batchIds.length === 0 && state.layerState.stagingArea.images.length === 0) {
|
||||||
|
resetStagingArea(state);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const resetStagingArea = (state: CanvasState) => {
|
||||||
|
state.layerState.stagingArea = { ...initialCanvasState.layerState.stagingArea };
|
||||||
|
state.shouldShowStagingImage = initialCanvasState.shouldShowStagingImage;
|
||||||
|
state.shouldShowStagingOutline = initialCanvasState.shouldShowStagingOutline;
|
||||||
|
};
|
||||||
|
@ -120,6 +120,7 @@ export interface CanvasState {
|
|||||||
shouldAutoSave: boolean;
|
shouldAutoSave: boolean;
|
||||||
shouldCropToBoundingBoxOnSave: boolean;
|
shouldCropToBoundingBoxOnSave: boolean;
|
||||||
shouldDarkenOutsideBoundingBox: boolean;
|
shouldDarkenOutsideBoundingBox: boolean;
|
||||||
|
shouldFitImageSize: boolean;
|
||||||
shouldInvertBrushSizeScrollDirection: boolean;
|
shouldInvertBrushSizeScrollDirection: boolean;
|
||||||
shouldLockBoundingBox: boolean;
|
shouldLockBoundingBox: boolean;
|
||||||
shouldPreserveMaskedArea: boolean;
|
shouldPreserveMaskedArea: boolean;
|
||||||
|
@ -33,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
<MetadataItem metadata={metadata} handlers={handlers.scheduler} />
|
<MetadataItem metadata={metadata} handlers={handlers.scheduler} />
|
||||||
<MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
|
<MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
|
||||||
<MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
|
<MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
|
||||||
|
<MetadataItem metadata={metadata} handlers={handlers.initialImage} />
|
||||||
<MetadataItem metadata={metadata} handlers={handlers.strength} />
|
<MetadataItem metadata={metadata} handlers={handlers.strength} />
|
||||||
<MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
|
<MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
|
||||||
<MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />
|
<MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />
|
||||||
|
@ -189,6 +189,12 @@ export const handlers = {
|
|||||||
recaller: recallers.cfgScale,
|
recaller: recallers.cfgScale,
|
||||||
}),
|
}),
|
||||||
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
|
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
|
||||||
|
initialImage: buildHandlers({
|
||||||
|
getLabel: () => t('metadata.initImage'),
|
||||||
|
parser: parsers.initialImage,
|
||||||
|
recaller: recallers.initialImage,
|
||||||
|
renderValue: async (imageDTO) => imageDTO.image_name,
|
||||||
|
}),
|
||||||
negativePrompt: buildHandlers({
|
negativePrompt: buildHandlers({
|
||||||
getLabel: () => t('metadata.negativePrompt'),
|
getLabel: () => t('metadata.negativePrompt'),
|
||||||
parser: parsers.negativePrompt,
|
parser: parsers.negativePrompt,
|
||||||
@ -405,6 +411,6 @@ export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
if (results.some((result) => result.status === 'fulfilled')) {
|
if (results.some((result) => result.status === 'fulfilled')) {
|
||||||
parameterSetToast(t('toast.parametersSet'));
|
parameterSetToast(t('toast.parameters'));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import {
|
import {
|
||||||
initialControlNet,
|
initialControlNet,
|
||||||
initialIPAdapter,
|
initialIPAdapter,
|
||||||
@ -57,6 +58,8 @@ import {
|
|||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { get, isArray, isString } from 'lodash-es';
|
import { get, isArray, isString } from 'lodash-es';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
isControlNetModelConfig,
|
isControlNetModelConfig,
|
||||||
isIPAdapterModelConfig,
|
isIPAdapterModelConfig,
|
||||||
@ -135,6 +138,14 @@ const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier
|
|||||||
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
|
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
|
||||||
getProperty(metadata, 'scheduler', isParameterScheduler);
|
getProperty(metadata, 'scheduler', isParameterScheduler);
|
||||||
|
|
||||||
|
const parseInitialImage: MetadataParseFunc<ImageDTO> = async (metadata) => {
|
||||||
|
const imageName = await getProperty(metadata, 'init_image', isString);
|
||||||
|
const imageDTORequest = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
|
||||||
|
const imageDTO = await imageDTORequest.unwrap();
|
||||||
|
imageDTORequest.unsubscribe();
|
||||||
|
return imageDTO;
|
||||||
|
};
|
||||||
|
|
||||||
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
|
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
|
||||||
|
|
||||||
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
|
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
|
||||||
@ -402,6 +413,7 @@ export const parsers = {
|
|||||||
cfgScale: parseCFGScale,
|
cfgScale: parseCFGScale,
|
||||||
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
|
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
|
||||||
scheduler: parseScheduler,
|
scheduler: parseScheduler,
|
||||||
|
initialImage: parseInitialImage,
|
||||||
width: parseWidth,
|
width: parseWidth,
|
||||||
height: parseHeight,
|
height: parseHeight,
|
||||||
steps: parseSteps,
|
steps: parseSteps,
|
||||||
|
@ -17,6 +17,7 @@ import type {
|
|||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
heightRecalled,
|
heightRecalled,
|
||||||
|
initialImageChanged,
|
||||||
setCfgRescaleMultiplier,
|
setCfgRescaleMultiplier,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
@ -61,6 +62,7 @@ import {
|
|||||||
setRefinerStart,
|
setRefinerStart,
|
||||||
setRefinerSteps,
|
setRefinerSteps,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||||
getStore().dispatch(setPositivePrompt(positivePrompt));
|
getStore().dispatch(setPositivePrompt(positivePrompt));
|
||||||
@ -94,6 +96,10 @@ const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
|
|||||||
getStore().dispatch(setScheduler(scheduler));
|
getStore().dispatch(setScheduler(scheduler));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const recallInitialImage: MetadataRecallFunc<ImageDTO> = async (imageDTO) => {
|
||||||
|
getStore().dispatch(initialImageChanged(imageDTO));
|
||||||
|
};
|
||||||
|
|
||||||
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
|
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
|
||||||
getStore().dispatch(widthRecalled(width));
|
getStore().dispatch(widthRecalled(width));
|
||||||
};
|
};
|
||||||
@ -235,6 +241,7 @@ export const recallers = {
|
|||||||
cfgScale: recallCFGScale,
|
cfgScale: recallCFGScale,
|
||||||
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
|
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
|
||||||
scheduler: recallScheduler,
|
scheduler: recallScheduler,
|
||||||
|
initialImage: recallInitialImage,
|
||||||
width: recallWidth,
|
width: recallWidth,
|
||||||
height: recallHeight,
|
height: recallHeight,
|
||||||
steps: recallSteps,
|
steps: recallSteps,
|
||||||
|
@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig } from 'app/store/store';
|
||||||
import type { ModelType } from 'services/api/types';
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||||
|
|
||||||
type ModelManagerState = {
|
type ModelManagerState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
|
@ -74,7 +74,6 @@ export const InstallModelForm = () => {
|
|||||||
onClick={handleSubmit(onSubmit)}
|
onClick={handleSubmit(onSubmit)}
|
||||||
isDisabled={!formState.dirtyFields.location}
|
isDisabled={!formState.dirtyFields.location}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
type="submit"
|
|
||||||
size="sm"
|
size="sm"
|
||||||
>
|
>
|
||||||
{t('modelManager.install')}
|
{t('modelManager.install')}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -9,10 +10,11 @@ import {
|
|||||||
useIPAdapterModels,
|
useIPAdapterModels,
|
||||||
useLoRAModels,
|
useLoRAModels,
|
||||||
useMainModels,
|
useMainModels,
|
||||||
|
useRefinerModels,
|
||||||
useT2IAdapterModels,
|
useT2IAdapterModels,
|
||||||
useVAEModels,
|
useVAEModels,
|
||||||
} from 'services/api/hooks/modelsByType';
|
} from 'services/api/hooks/modelsByType';
|
||||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
@ -27,6 +29,12 @@ const ModelList = () => {
|
|||||||
[mainModels, searchTerm, filteredModelType]
|
[mainModels, searchTerm, filteredModelType]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
|
||||||
|
const filteredRefinerModels = useMemo(
|
||||||
|
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
|
||||||
|
[refinerModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||||
const filteredLoRAModels = useMemo(
|
const filteredLoRAModels = useMemo(
|
||||||
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||||
@ -63,6 +71,28 @@ const ModelList = () => {
|
|||||||
[vaeModels, searchTerm, filteredModelType]
|
[vaeModels, searchTerm, filteredModelType]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const totalFilteredModels = useMemo(() => {
|
||||||
|
return (
|
||||||
|
filteredMainModels.length +
|
||||||
|
filteredRefinerModels.length +
|
||||||
|
filteredLoRAModels.length +
|
||||||
|
filteredEmbeddingModels.length +
|
||||||
|
filteredControlNetModels.length +
|
||||||
|
filteredT2IAdapterModels.length +
|
||||||
|
filteredIPAdapterModels.length +
|
||||||
|
filteredVAEModels.length
|
||||||
|
);
|
||||||
|
}, [
|
||||||
|
filteredControlNetModels.length,
|
||||||
|
filteredEmbeddingModels.length,
|
||||||
|
filteredIPAdapterModels.length,
|
||||||
|
filteredLoRAModels.length,
|
||||||
|
filteredMainModels.length,
|
||||||
|
filteredRefinerModels.length,
|
||||||
|
filteredT2IAdapterModels.length,
|
||||||
|
filteredVAEModels.length,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||||
@ -71,6 +101,11 @@ const ModelList = () => {
|
|||||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||||
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
||||||
)}
|
)}
|
||||||
|
{/* Refiner Model List */}
|
||||||
|
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
|
||||||
|
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
|
||||||
|
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
|
||||||
|
)}
|
||||||
{/* LoRAs List */}
|
{/* LoRAs List */}
|
||||||
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||||
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||||
@ -108,6 +143,11 @@ const ModelList = () => {
|
|||||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||||
)}
|
)}
|
||||||
|
{totalFilteredModels === 0 && (
|
||||||
|
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Text>{t('modelManager.noMatchingModels')}</Text>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
);
|
);
|
||||||
@ -118,12 +158,24 @@ export default memo(ModelList);
|
|||||||
const modelsFilter = <T extends AnyModelConfig>(
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
data: T[],
|
data: T[],
|
||||||
nameFilter: string,
|
nameFilter: string,
|
||||||
filteredModelType: ModelType | null
|
filteredModelType: FilterableModelType | null
|
||||||
): T[] => {
|
): T[] => {
|
||||||
return data.filter((model) => {
|
return data.filter((model) => {
|
||||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
const matchesType = getMatchesType(model, filteredModelType);
|
||||||
|
|
||||||
return matchesFilter && matchesType;
|
return matchesFilter && matchesType;
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
|
||||||
|
if (filteredModelType === 'refiner') {
|
||||||
|
return modelConfig.base === 'sdxl-refiner';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredModelType ? modelConfig.type === filteredModelType : true;
|
||||||
|
};
|
||||||
|
@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
|
|||||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
main: t('modelManager.main'),
|
main: t('modelManager.main'),
|
||||||
|
refiner: t('sdxl.refiner'),
|
||||||
lora: 'LoRA',
|
lora: 'LoRA',
|
||||||
embedding: t('modelManager.textualInversions'),
|
embedding: t('modelManager.textualInversions'),
|
||||||
controlnet: 'ControlNet',
|
controlnet: 'ControlNet',
|
||||||
|
@ -86,7 +86,6 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
colorScheme="invokeYellow"
|
colorScheme="invokeYellow"
|
||||||
isDisabled={!formState.isDirty}
|
isDisabled={!formState.isDirty}
|
||||||
onClick={handleSubmit(onSubmit)}
|
onClick={handleSubmit(onSubmit)}
|
||||||
type="submit"
|
|
||||||
isLoading={isLoadingUpdateModel}
|
isLoading={isLoadingUpdateModel}
|
||||||
>
|
>
|
||||||
{t('common.save')}
|
{t('common.save')}
|
||||||
|
@ -116,7 +116,6 @@ export const MainModelDefaultSettings = () => {
|
|||||||
colorScheme="invokeYellow"
|
colorScheme="invokeYellow"
|
||||||
isDisabled={!formState.isDirty}
|
isDisabled={!formState.isDirty}
|
||||||
onClick={handleSubmit(onSubmit)}
|
onClick={handleSubmit(onSubmit)}
|
||||||
type="submit"
|
|
||||||
isLoading={isLoadingUpdateModel}
|
isLoading={isLoadingUpdateModel}
|
||||||
>
|
>
|
||||||
{t('common.save')}
|
{t('common.save')}
|
||||||
|
@ -88,7 +88,6 @@ export const TriggerPhrases = () => {
|
|||||||
<Button
|
<Button
|
||||||
leftIcon={<PiPlusBold />}
|
leftIcon={<PiPlusBold />}
|
||||||
size="sm"
|
size="sm"
|
||||||
type="submit"
|
|
||||||
onClick={addTriggerPhrase}
|
onClick={addTriggerPhrase}
|
||||||
isDisabled={!phrase || Boolean(errors.length)}
|
isDisabled={!phrase || Boolean(errors.length)}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
|
@ -3,6 +3,7 @@ import 'reactflow/dist/style.css';
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||||
|
import { LoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
|
||||||
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
|
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
|
||||||
import type { AnimationProps } from 'framer-motion';
|
import type { AnimationProps } from 'framer-motion';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
@ -61,6 +62,7 @@ const NodeEditor = () => {
|
|||||||
<BottomLeftPanel />
|
<BottomLeftPanel />
|
||||||
<MinimapPanel />
|
<MinimapPanel />
|
||||||
<SaveWorkflowAsDialog />
|
<SaveWorkflowAsDialog />
|
||||||
|
<LoadWorkflowFromGraphModal />
|
||||||
</motion.div>
|
</motion.div>
|
||||||
)}
|
)}
|
||||||
</AnimatePresence>
|
</AnimatePresence>
|
||||||
|
@ -37,34 +37,50 @@ const NumberFieldInputComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const min = useMemo(() => {
|
const min = useMemo(() => {
|
||||||
|
let min = -NUMPY_RAND_MAX;
|
||||||
if (!isNil(fieldTemplate.minimum)) {
|
if (!isNil(fieldTemplate.minimum)) {
|
||||||
return fieldTemplate.minimum;
|
min = fieldTemplate.minimum;
|
||||||
}
|
}
|
||||||
if (!isNil(fieldTemplate.exclusiveMinimum)) {
|
if (!isNil(fieldTemplate.exclusiveMinimum)) {
|
||||||
return fieldTemplate.exclusiveMinimum + 0.01;
|
min = fieldTemplate.exclusiveMinimum + 0.01;
|
||||||
}
|
}
|
||||||
return;
|
return min;
|
||||||
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
|
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
|
||||||
|
|
||||||
const max = useMemo(() => {
|
const max = useMemo(() => {
|
||||||
|
let max = NUMPY_RAND_MAX;
|
||||||
if (!isNil(fieldTemplate.maximum)) {
|
if (!isNil(fieldTemplate.maximum)) {
|
||||||
return fieldTemplate.maximum;
|
max = fieldTemplate.maximum;
|
||||||
}
|
}
|
||||||
if (!isNil(fieldTemplate.exclusiveMaximum)) {
|
if (!isNil(fieldTemplate.exclusiveMaximum)) {
|
||||||
return fieldTemplate.exclusiveMaximum - 0.01;
|
max = fieldTemplate.exclusiveMaximum - 0.01;
|
||||||
}
|
}
|
||||||
return;
|
return max;
|
||||||
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
|
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
|
||||||
|
|
||||||
|
const step = useMemo(() => {
|
||||||
|
if (isNil(fieldTemplate.multipleOf)) {
|
||||||
|
return isIntegerField ? 1 : 0.1;
|
||||||
|
}
|
||||||
|
return fieldTemplate.multipleOf;
|
||||||
|
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||||
|
|
||||||
|
const fineStep = useMemo(() => {
|
||||||
|
if (isNil(fieldTemplate.multipleOf)) {
|
||||||
|
return isIntegerField ? 1 : 0.01;
|
||||||
|
}
|
||||||
|
return fieldTemplate.multipleOf;
|
||||||
|
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
defaultValue={fieldTemplate.default}
|
defaultValue={fieldTemplate.default}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value}
|
value={field.value}
|
||||||
min={min ?? -NUMPY_RAND_MAX}
|
min={min}
|
||||||
max={max ?? NUMPY_RAND_MAX}
|
max={max}
|
||||||
step={isIntegerField ? 1 : 0.1}
|
step={step}
|
||||||
fineStep={isIntegerField ? 1 : 0.01}
|
fineStep={fineStep}
|
||||||
className="nodrag"
|
className="nodrag"
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,26 +1,18 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
|
||||||
type CreateDenoiseMaskInvocation,
|
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||||
type ImageDTO,
|
|
||||||
isRefinerMainModelModelConfig,
|
|
||||||
type NonNullableGraph,
|
|
||||||
type SeamlessModeInvocation,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
INPAINT_IMAGE_RESIZE_UP,
|
INPAINT_CREATE_MASK,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MASK_COMBINE,
|
|
||||||
MASK_RESIZE_UP,
|
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
SDXL_REFINER_DENOISE_LATENTS,
|
SDXL_REFINER_DENOISE_LATENTS,
|
||||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
SDXL_REFINER_MODEL_LOADER,
|
SDXL_REFINER_MODEL_LOADER,
|
||||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
@ -33,9 +25,7 @@ export const addSDXLRefinerToGraph = async (
|
|||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId?: string,
|
modelLoaderNodeId?: string
|
||||||
canvasInitImage?: ImageDTO,
|
|
||||||
canvasMaskImage?: ImageDTO
|
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const {
|
const {
|
||||||
refinerModel,
|
refinerModel,
|
||||||
@ -51,11 +41,9 @@ export const addSDXLRefinerToGraph = async (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
|
const { seamlessXAxis, seamlessYAxis } = state.generation;
|
||||||
const { boundingBoxScaleMethod } = state.canvas;
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
const fp32 = vaePrecision === 'fp32';
|
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||||
|
|
||||||
@ -214,67 +202,9 @@ export const addSDXLRefinerToGraph = async (
|
|||||||
);
|
);
|
||||||
|
|
||||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
||||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
|
||||||
type: 'create_denoise_mask',
|
|
||||||
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
is_intermediate: true,
|
|
||||||
fp32,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (isUsingScaledDimensions) {
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: {
|
source: {
|
||||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
node_id: INPAINT_CREATE_MASK,
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
|
||||||
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
|
|
||||||
image: canvasInitImage,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
|
|
||||||
if (isUsingScaledDimensions) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: MASK_RESIZE_UP,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
field: 'mask',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
|
||||||
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
|
|
||||||
mask: canvasMaskImage,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
field: 'mask',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
field: 'denoise_mask',
|
field: 'denoise_mask',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
|
@ -17,7 +17,6 @@ import {
|
|||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
@ -166,27 +165,6 @@ export const addVAEToGraph = async (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (refinerModel) {
|
|
||||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: isSeamlessEnabled
|
|
||||||
? isUsingRefiner
|
|
||||||
? SDXL_REFINER_SEAMLESS
|
|
||||||
: SEAMLESS
|
|
||||||
: isAutoVae
|
|
||||||
? modelLoaderNodeId
|
|
||||||
: VAE_LOADER,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
upsertMetadata(graph, { vae });
|
upsertMetadata(graph, { vae });
|
||||||
}
|
}
|
||||||
|
@ -65,6 +65,11 @@ export const buildCanvasOutpaintGraph = async (
|
|||||||
infillTileSize,
|
infillTileSize,
|
||||||
infillPatchmatchDownscaleSize,
|
infillPatchmatchDownscaleSize,
|
||||||
infillMethod,
|
infillMethod,
|
||||||
|
// infillMosaicTileWidth,
|
||||||
|
// infillMosaicTileHeight,
|
||||||
|
// infillMosaicMinColor,
|
||||||
|
// infillMosaicMaxColor,
|
||||||
|
infillColorValue,
|
||||||
clipSkip,
|
clipSkip,
|
||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
seamlessYAxis,
|
seamlessYAxis,
|
||||||
@ -356,6 +361,28 @@ export const buildCanvasOutpaintGraph = async (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: add mosaic back
|
||||||
|
// if (infillMethod === 'mosaic') {
|
||||||
|
// graph.nodes[INPAINT_INFILL] = {
|
||||||
|
// type: 'infill_mosaic',
|
||||||
|
// id: INPAINT_INFILL,
|
||||||
|
// is_intermediate,
|
||||||
|
// tile_width: infillMosaicTileWidth,
|
||||||
|
// tile_height: infillMosaicTileHeight,
|
||||||
|
// min_color: infillMosaicMinColor,
|
||||||
|
// max_color: infillMosaicMaxColor,
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
|
||||||
|
if (infillMethod === 'color') {
|
||||||
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
|
type: 'infill_rgba',
|
||||||
|
id: INPAINT_INFILL,
|
||||||
|
color: infillColorValue,
|
||||||
|
is_intermediate,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (isUsingScaledDimensions) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
|
@ -133,7 +133,7 @@ export const buildCanvasSDXLInpaintGraph = async (
|
|||||||
id: INPAINT_CREATE_MASK,
|
id: INPAINT_CREATE_MASK,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
minimum_denoise: canvasCoherenceMinDenoise,
|
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
@ -426,14 +426,7 @@ export const buildCanvasSDXLInpaintGraph = async (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
await addSDXLRefinerToGraph(
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
state,
|
|
||||||
graph,
|
|
||||||
SDXL_DENOISE_LATENTS,
|
|
||||||
modelLoaderNodeId,
|
|
||||||
canvasInitImage,
|
|
||||||
canvasMaskImage
|
|
||||||
);
|
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,11 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
infillTileSize,
|
infillTileSize,
|
||||||
infillPatchmatchDownscaleSize,
|
infillPatchmatchDownscaleSize,
|
||||||
infillMethod,
|
infillMethod,
|
||||||
|
// infillMosaicTileWidth,
|
||||||
|
// infillMosaicTileHeight,
|
||||||
|
// infillMosaicMinColor,
|
||||||
|
// infillMosaicMaxColor,
|
||||||
|
infillColorValue,
|
||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
seamlessYAxis,
|
seamlessYAxis,
|
||||||
canvasCoherenceMode,
|
canvasCoherenceMode,
|
||||||
@ -151,7 +156,7 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
is_intermediate,
|
is_intermediate,
|
||||||
coherence_mode: canvasCoherenceMode,
|
coherence_mode: canvasCoherenceMode,
|
||||||
edge_radius: canvasCoherenceEdgeSize,
|
edge_radius: canvasCoherenceEdgeSize,
|
||||||
minimum_denoise: canvasCoherenceMinDenoise,
|
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||||
},
|
},
|
||||||
[SDXL_DENOISE_LATENTS]: {
|
[SDXL_DENOISE_LATENTS]: {
|
||||||
type: 'denoise_latents',
|
type: 'denoise_latents',
|
||||||
@ -365,6 +370,28 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: add mosaic back
|
||||||
|
// if (infillMethod === 'mosaic') {
|
||||||
|
// graph.nodes[INPAINT_INFILL] = {
|
||||||
|
// type: 'infill_mosaic',
|
||||||
|
// id: INPAINT_INFILL,
|
||||||
|
// is_intermediate,
|
||||||
|
// tile_width: infillMosaicTileWidth,
|
||||||
|
// tile_height: infillMosaicTileHeight,
|
||||||
|
// min_color: infillMosaicMinColor,
|
||||||
|
// max_color: infillMosaicMaxColor,
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
|
||||||
|
if (infillMethod === 'color') {
|
||||||
|
graph.nodes[INPAINT_INFILL] = {
|
||||||
|
type: 'infill_rgba',
|
||||||
|
id: INPAINT_INFILL,
|
||||||
|
is_intermediate,
|
||||||
|
color: infillColorValue,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Scale Before Processing
|
// Handle Scale Before Processing
|
||||||
if (isUsingScaledDimensions) {
|
if (isUsingScaledDimensions) {
|
||||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||||
@ -555,7 +582,7 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,148 @@
|
|||||||
|
import * as dagre from '@dagrejs/dagre';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||||
|
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||||
|
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||||
|
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import type { NonNullableGraph } from 'services/api/types';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a graph to a workflow. This is a best-effort conversion and may not be perfect.
|
||||||
|
* For example, if a graph references an unknown node type, that node will be skipped.
|
||||||
|
* @param graph The graph to convert to a workflow
|
||||||
|
* @param autoLayout Whether to auto-layout the nodes using `dagre`. If false, nodes will be simply stacked on top of one another with an offset.
|
||||||
|
* @returns The workflow.
|
||||||
|
*/
|
||||||
|
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
|
||||||
|
const invocationTemplates = getStore().getState().nodes.templates;
|
||||||
|
|
||||||
|
if (!invocationTemplates) {
|
||||||
|
throw new Error(t('app.storeNotInitialized'));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the workflow
|
||||||
|
const workflow: WorkflowV3 = {
|
||||||
|
name: '',
|
||||||
|
author: '',
|
||||||
|
contact: '',
|
||||||
|
description: '',
|
||||||
|
meta: {
|
||||||
|
category: 'user',
|
||||||
|
version: '3.0.0',
|
||||||
|
},
|
||||||
|
notes: '',
|
||||||
|
tags: '',
|
||||||
|
version: '',
|
||||||
|
exposedFields: [],
|
||||||
|
edges: [],
|
||||||
|
nodes: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert nodes
|
||||||
|
forEach(graph.nodes, (node) => {
|
||||||
|
const template = invocationTemplates[node.type];
|
||||||
|
|
||||||
|
// Skip missing node templates - this is a best-effort
|
||||||
|
if (!template) {
|
||||||
|
logger('nodes').warn(`Node type ${node.type} not found in invocationTemplates`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build field input instances for each attr
|
||||||
|
const inputs: Record<string, FieldInputInstance> = {};
|
||||||
|
|
||||||
|
forEach(node, (value, key) => {
|
||||||
|
// Ignore the non-input keys - I think this is all of them?
|
||||||
|
if (key === 'id' || key === 'type' || key === 'is_intermediate' || key === 'use_cache') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const inputTemplate = template.inputs[key];
|
||||||
|
|
||||||
|
// Skip missing input templates
|
||||||
|
if (!inputTemplate) {
|
||||||
|
logger('nodes').warn(`Input ${key} not found in template for node type ${node.type}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This _should_ be all we need to do!
|
||||||
|
const inputInstance = buildFieldInputInstance(node.id, inputTemplate);
|
||||||
|
inputInstance.value = value;
|
||||||
|
inputs[key] = inputInstance;
|
||||||
|
});
|
||||||
|
|
||||||
|
workflow.nodes.push({
|
||||||
|
id: node.id,
|
||||||
|
type: 'invocation',
|
||||||
|
position: { x: 0, y: 0 }, // we'll do layout later, just need something here
|
||||||
|
data: {
|
||||||
|
id: node.id,
|
||||||
|
type: node.type,
|
||||||
|
version: template.version,
|
||||||
|
label: '',
|
||||||
|
notes: '',
|
||||||
|
isOpen: true,
|
||||||
|
isIntermediate: node.is_intermediate ?? false,
|
||||||
|
useCache: node.use_cache ?? true,
|
||||||
|
inputs,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
forEach(graph.edges, (edge) => {
|
||||||
|
workflow.edges.push({
|
||||||
|
id: uuidv4(), // we don't have edge IDs in the graph
|
||||||
|
type: 'default',
|
||||||
|
source: edge.source.node_id,
|
||||||
|
sourceHandle: edge.source.field,
|
||||||
|
target: edge.destination.node_id,
|
||||||
|
targetHandle: edge.destination.field,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (autoLayout) {
|
||||||
|
// Best-effort auto layout via dagre - not perfect but better than nothing
|
||||||
|
const dagreGraph = new dagre.graphlib.Graph();
|
||||||
|
// `rankdir` and `align` could be tweaked, but it's gonna be janky no matter what we choose
|
||||||
|
dagreGraph.setGraph({ rankdir: 'TB', align: 'UL' });
|
||||||
|
dagreGraph.setDefaultEdgeLabel(() => ({}));
|
||||||
|
|
||||||
|
// We don't know the dimensions of the nodes until we load the graph into `reactflow` - use a reasonable value
|
||||||
|
forEach(graph.nodes, (node) => {
|
||||||
|
const width = NODE_WIDTH;
|
||||||
|
const height = NODE_WIDTH * 1.5;
|
||||||
|
dagreGraph.setNode(node.id, { width, height });
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.forEach((edge) => {
|
||||||
|
dagreGraph.setEdge(edge.source.node_id, edge.destination.node_id);
|
||||||
|
});
|
||||||
|
|
||||||
|
// This does the magic
|
||||||
|
dagre.layout(dagreGraph);
|
||||||
|
|
||||||
|
// Update the workflow now that we've got the positions
|
||||||
|
workflow.nodes.forEach((node) => {
|
||||||
|
const nodeWithPosition = dagreGraph.node(node.id);
|
||||||
|
node.position = {
|
||||||
|
x: nodeWithPosition.x - nodeWithPosition.width / 2,
|
||||||
|
y: nodeWithPosition.y - nodeWithPosition.height / 2,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Stack nodes with a 50px,50px offset from the previous ndoe
|
||||||
|
let x = 0;
|
||||||
|
let y = 0;
|
||||||
|
workflow.nodes.forEach((node) => {
|
||||||
|
node.position = { x, y };
|
||||||
|
x = x + 50;
|
||||||
|
y = y + 50;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return workflow;
|
||||||
|
};
|
@ -0,0 +1,46 @@
|
|||||||
|
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
|
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import type { RgbaColor } from 'react-colorful';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const ParamInfillColorOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectGenerationSlice, (generation) => ({
|
||||||
|
infillColor: generation.infillColorValue,
|
||||||
|
})),
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { infillColor } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleInfillColor = useCallback(
|
||||||
|
(v: RgbaColor) => {
|
||||||
|
dispatch(setInfillColorValue(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" gap={4}>
|
||||||
|
<FormControl isDisabled={infillMethod !== 'color'}>
|
||||||
|
<FormLabel>{t('parameters.infillColorValue')}</FormLabel>
|
||||||
|
<Box w="full" pt={2} pb={2}>
|
||||||
|
<IAIColorPicker color={infillColor} onChange={handleInfillColor} />
|
||||||
|
</Box>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamInfillColorOptions);
|
@ -0,0 +1,127 @@
|
|||||||
|
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
|
import {
|
||||||
|
selectGenerationSlice,
|
||||||
|
setInfillMosaicMaxColor,
|
||||||
|
setInfillMosaicMinColor,
|
||||||
|
setInfillMosaicTileHeight,
|
||||||
|
setInfillMosaicTileWidth,
|
||||||
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import type { RgbaColor } from 'react-colorful';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const ParamInfillMosaicTileSize = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(selectGenerationSlice, (generation) => ({
|
||||||
|
infillMosaicTileWidth: generation.infillMosaicTileWidth,
|
||||||
|
infillMosaicTileHeight: generation.infillMosaicTileHeight,
|
||||||
|
infillMosaicMinColor: generation.infillMosaicMinColor,
|
||||||
|
infillMosaicMaxColor: generation.infillMosaicMaxColor,
|
||||||
|
})),
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
|
||||||
|
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleInfillMosaicTileWidthChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(setInfillMosaicTileWidth(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleInfillMosaicTileHeightChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(setInfillMosaicTileHeight(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleInfillMosaicMinColor = useCallback(
|
||||||
|
(v: RgbaColor) => {
|
||||||
|
dispatch(setInfillMosaicMinColor(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleInfillMosaicMaxColor = useCallback(
|
||||||
|
(v: RgbaColor) => {
|
||||||
|
dispatch(setInfillMosaicMaxColor(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" gap={4}>
|
||||||
|
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||||
|
<FormLabel>{t('parameters.infillMosaicTileWidth')}</FormLabel>
|
||||||
|
<CompositeSlider
|
||||||
|
min={8}
|
||||||
|
max={256}
|
||||||
|
value={infillMosaicTileWidth}
|
||||||
|
defaultValue={64}
|
||||||
|
onChange={handleInfillMosaicTileWidthChange}
|
||||||
|
step={8}
|
||||||
|
fineStep={8}
|
||||||
|
marks
|
||||||
|
/>
|
||||||
|
<CompositeNumberInput
|
||||||
|
min={8}
|
||||||
|
max={1024}
|
||||||
|
value={infillMosaicTileWidth}
|
||||||
|
defaultValue={64}
|
||||||
|
onChange={handleInfillMosaicTileWidthChange}
|
||||||
|
step={8}
|
||||||
|
fineStep={8}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||||
|
<FormLabel>{t('parameters.infillMosaicTileHeight')}</FormLabel>
|
||||||
|
<CompositeSlider
|
||||||
|
min={8}
|
||||||
|
max={256}
|
||||||
|
value={infillMosaicTileHeight}
|
||||||
|
defaultValue={64}
|
||||||
|
onChange={handleInfillMosaicTileHeightChange}
|
||||||
|
step={8}
|
||||||
|
fineStep={8}
|
||||||
|
marks
|
||||||
|
/>
|
||||||
|
<CompositeNumberInput
|
||||||
|
min={8}
|
||||||
|
max={1024}
|
||||||
|
value={infillMosaicTileHeight}
|
||||||
|
defaultValue={64}
|
||||||
|
onChange={handleInfillMosaicTileHeightChange}
|
||||||
|
step={8}
|
||||||
|
fineStep={8}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||||
|
<FormLabel>{t('parameters.infillMosaicMinColor')}</FormLabel>
|
||||||
|
<Box w="full" pt={2} pb={2}>
|
||||||
|
<IAIColorPicker color={infillMosaicMinColor} onChange={handleInfillMosaicMinColor} />
|
||||||
|
</Box>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||||
|
<FormLabel>{t('parameters.infillMosaicMaxColor')}</FormLabel>
|
||||||
|
<Box w="full" pt={2} pb={2}>
|
||||||
|
<IAIColorPicker color={infillMosaicMaxColor} onChange={handleInfillMosaicMaxColor} />
|
||||||
|
</Box>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamInfillMosaicTileSize);
|
@ -1,6 +1,8 @@
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
import ParamInfillColorOptions from './ParamInfillColorOptions';
|
||||||
|
import ParamInfillMosaicOptions from './ParamInfillMosaicOptions';
|
||||||
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
|
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
|
||||||
import ParamInfillTilesize from './ParamInfillTilesize';
|
import ParamInfillTilesize from './ParamInfillTilesize';
|
||||||
|
|
||||||
@ -14,6 +16,14 @@ const ParamInfillOptions = () => {
|
|||||||
return <ParamInfillPatchmatchDownscaleSize />;
|
return <ParamInfillPatchmatchDownscaleSize />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (infillMethod === 'mosaic') {
|
||||||
|
return <ParamInfillMosaicOptions />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (infillMethod === 'color') {
|
||||||
|
return <ParamInfillColorOptions />;
|
||||||
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ import type {
|
|||||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
|
import type { RgbaColor } from 'react-colorful';
|
||||||
import type { ImageDTO } from 'services/api/types';
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
import type { GenerationState } from './types';
|
import type { GenerationState } from './types';
|
||||||
@ -43,8 +44,6 @@ const initialGenerationState: GenerationState = {
|
|||||||
shouldFitToWidthHeight: true,
|
shouldFitToWidthHeight: true,
|
||||||
shouldRandomizeSeed: true,
|
shouldRandomizeSeed: true,
|
||||||
steps: 50,
|
steps: 50,
|
||||||
infillTileSize: 32,
|
|
||||||
infillPatchmatchDownscaleSize: 1,
|
|
||||||
width: 512,
|
width: 512,
|
||||||
model: null,
|
model: null,
|
||||||
vae: null,
|
vae: null,
|
||||||
@ -55,6 +54,13 @@ const initialGenerationState: GenerationState = {
|
|||||||
shouldUseCpuNoise: true,
|
shouldUseCpuNoise: true,
|
||||||
shouldShowAdvancedOptions: false,
|
shouldShowAdvancedOptions: false,
|
||||||
aspectRatio: { ...initialAspectRatioState },
|
aspectRatio: { ...initialAspectRatioState },
|
||||||
|
infillTileSize: 32,
|
||||||
|
infillPatchmatchDownscaleSize: 1,
|
||||||
|
infillMosaicTileWidth: 64,
|
||||||
|
infillMosaicTileHeight: 64,
|
||||||
|
infillMosaicMinColor: { r: 0, g: 0, b: 0, a: 1 },
|
||||||
|
infillMosaicMaxColor: { r: 255, g: 255, b: 255, a: 1 },
|
||||||
|
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||||
};
|
};
|
||||||
|
|
||||||
export const generationSlice = createSlice({
|
export const generationSlice = createSlice({
|
||||||
@ -116,15 +122,6 @@ export const generationSlice = createSlice({
|
|||||||
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
|
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
|
||||||
state.canvasCoherenceMinDenoise = action.payload;
|
state.canvasCoherenceMinDenoise = action.payload;
|
||||||
},
|
},
|
||||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
|
||||||
state.infillMethod = action.payload;
|
|
||||||
},
|
|
||||||
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
|
||||||
state.infillTileSize = action.payload;
|
|
||||||
},
|
|
||||||
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
|
||||||
state.infillPatchmatchDownscaleSize = action.payload;
|
|
||||||
},
|
|
||||||
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
const { image_name, width, height } = action.payload;
|
const { image_name, width, height } = action.payload;
|
||||||
state.initialImage = { imageName: image_name, width, height };
|
state.initialImage = { imageName: image_name, width, height };
|
||||||
@ -206,6 +203,30 @@ export const generationSlice = createSlice({
|
|||||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||||
state.aspectRatio = action.payload;
|
state.aspectRatio = action.payload;
|
||||||
},
|
},
|
||||||
|
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||||
|
state.infillMethod = action.payload;
|
||||||
|
},
|
||||||
|
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||||
|
state.infillTileSize = action.payload;
|
||||||
|
},
|
||||||
|
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
||||||
|
state.infillPatchmatchDownscaleSize = action.payload;
|
||||||
|
},
|
||||||
|
setInfillMosaicTileWidth: (state, action: PayloadAction<number>) => {
|
||||||
|
state.infillMosaicTileWidth = action.payload;
|
||||||
|
},
|
||||||
|
setInfillMosaicTileHeight: (state, action: PayloadAction<number>) => {
|
||||||
|
state.infillMosaicTileHeight = action.payload;
|
||||||
|
},
|
||||||
|
setInfillMosaicMinColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||||
|
state.infillMosaicMinColor = action.payload;
|
||||||
|
},
|
||||||
|
setInfillMosaicMaxColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||||
|
state.infillMosaicMaxColor = action.payload;
|
||||||
|
},
|
||||||
|
setInfillColorValue: (state, action: PayloadAction<RgbaColor>) => {
|
||||||
|
state.infillColorValue = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
@ -249,8 +270,6 @@ export const {
|
|||||||
setShouldFitToWidthHeight,
|
setShouldFitToWidthHeight,
|
||||||
setShouldRandomizeSeed,
|
setShouldRandomizeSeed,
|
||||||
setSteps,
|
setSteps,
|
||||||
setInfillTileSize,
|
|
||||||
setInfillPatchmatchDownscaleSize,
|
|
||||||
initialImageChanged,
|
initialImageChanged,
|
||||||
modelChanged,
|
modelChanged,
|
||||||
vaeSelected,
|
vaeSelected,
|
||||||
@ -264,6 +283,13 @@ export const {
|
|||||||
heightChanged,
|
heightChanged,
|
||||||
widthRecalled,
|
widthRecalled,
|
||||||
heightRecalled,
|
heightRecalled,
|
||||||
|
setInfillTileSize,
|
||||||
|
setInfillPatchmatchDownscaleSize,
|
||||||
|
setInfillMosaicTileWidth,
|
||||||
|
setInfillMosaicTileHeight,
|
||||||
|
setInfillMosaicMinColor,
|
||||||
|
setInfillMosaicMaxColor,
|
||||||
|
setInfillColorValue,
|
||||||
} = generationSlice.actions;
|
} = generationSlice.actions;
|
||||||
|
|
||||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||||
|
@ -17,6 +17,7 @@ import type {
|
|||||||
ParameterVAEModel,
|
ParameterVAEModel,
|
||||||
ParameterWidth,
|
ParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import type { RgbaColor } from 'react-colorful';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
_version: 2;
|
_version: 2;
|
||||||
@ -39,8 +40,6 @@ export interface GenerationState {
|
|||||||
shouldFitToWidthHeight: boolean;
|
shouldFitToWidthHeight: boolean;
|
||||||
shouldRandomizeSeed: boolean;
|
shouldRandomizeSeed: boolean;
|
||||||
steps: ParameterSteps;
|
steps: ParameterSteps;
|
||||||
infillTileSize: number;
|
|
||||||
infillPatchmatchDownscaleSize: number;
|
|
||||||
width: ParameterWidth;
|
width: ParameterWidth;
|
||||||
model: ParameterModel | null;
|
model: ParameterModel | null;
|
||||||
vae: ParameterVAEModel | null;
|
vae: ParameterVAEModel | null;
|
||||||
@ -51,6 +50,13 @@ export interface GenerationState {
|
|||||||
shouldUseCpuNoise: boolean;
|
shouldUseCpuNoise: boolean;
|
||||||
shouldShowAdvancedOptions: boolean;
|
shouldShowAdvancedOptions: boolean;
|
||||||
aspectRatio: AspectRatioState;
|
aspectRatio: AspectRatioState;
|
||||||
|
infillTileSize: number;
|
||||||
|
infillPatchmatchDownscaleSize: number;
|
||||||
|
infillMosaicTileWidth: number;
|
||||||
|
infillMosaicTileHeight: number;
|
||||||
|
infillMosaicMinColor: RgbaColor;
|
||||||
|
infillMosaicMaxColor: RgbaColor;
|
||||||
|
infillColorValue: RgbaColor;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
|
||||||
import type { InvokeTabName } from './tabMap';
|
import type { InvokeTabName } from './tabMap';
|
||||||
@ -45,6 +46,9 @@ export const uiSlice = createSlice({
|
|||||||
builder.addCase(initialImageChanged, (state) => {
|
builder.addCase(initialImageChanged, (state) => {
|
||||||
state.activeTab = 'img2img';
|
state.activeTab = 'img2img';
|
||||||
});
|
});
|
||||||
|
builder.addCase(workflowLoadRequested, (state) => {
|
||||||
|
state.activeTab = 'nodes';
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -0,0 +1,111 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Checkbox,
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Modal,
|
||||||
|
ModalBody,
|
||||||
|
ModalCloseButton,
|
||||||
|
ModalContent,
|
||||||
|
ModalHeader,
|
||||||
|
ModalOverlay,
|
||||||
|
Spacer,
|
||||||
|
Textarea,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
|
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||||
|
import { atom } from 'nanostores';
|
||||||
|
import type { ChangeEvent } from 'react';
|
||||||
|
import { useCallback, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const $isOpen = atom<boolean>(false);
|
||||||
|
|
||||||
|
export const useLoadWorkflowFromGraphModal = () => {
|
||||||
|
const isOpen = useStore($isOpen);
|
||||||
|
const onOpen = useCallback(() => {
|
||||||
|
$isOpen.set(true);
|
||||||
|
}, []);
|
||||||
|
const onClose = useCallback(() => {
|
||||||
|
$isOpen.set(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return { isOpen, onOpen, onClose };
|
||||||
|
};
|
||||||
|
|
||||||
|
export const LoadWorkflowFromGraphModal = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isOpen, onClose } = useLoadWorkflowFromGraphModal();
|
||||||
|
const [graphRaw, setGraphRaw] = useState<string>('');
|
||||||
|
const [workflowRaw, setWorkflowRaw] = useState<string>('');
|
||||||
|
const [shouldAutoLayout, setShouldAutoLayout] = useState(true);
|
||||||
|
const onChangeGraphRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
setGraphRaw(e.target.value);
|
||||||
|
}, []);
|
||||||
|
const onChangeWorkflowRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
setWorkflowRaw(e.target.value);
|
||||||
|
}, []);
|
||||||
|
const onChangeShouldAutoLayout = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
setShouldAutoLayout(e.target.checked);
|
||||||
|
}, []);
|
||||||
|
const parse = useCallback(() => {
|
||||||
|
const graph = JSON.parse(graphRaw);
|
||||||
|
const workflow = graphToWorkflow(graph, shouldAutoLayout);
|
||||||
|
setWorkflowRaw(JSON.stringify(workflow, null, 2));
|
||||||
|
}, [graphRaw, shouldAutoLayout]);
|
||||||
|
const loadWorkflow = useCallback(() => {
|
||||||
|
const workflow = JSON.parse(workflowRaw);
|
||||||
|
dispatch(workflowLoadRequested({ workflow, asCopy: true }));
|
||||||
|
onClose();
|
||||||
|
}, [dispatch, onClose, workflowRaw]);
|
||||||
|
return (
|
||||||
|
<Modal isOpen={isOpen} onClose={onClose} isCentered>
|
||||||
|
<ModalOverlay />
|
||||||
|
<ModalContent w="80vw" h="80vh" maxW="unset" maxH="unset">
|
||||||
|
<ModalHeader>{t('workflows.loadFromGraph')}</ModalHeader>
|
||||||
|
<ModalCloseButton />
|
||||||
|
<ModalBody as={Flex} flexDir="column" gap={4} w="full" h="full" pb={4}>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<Button onClick={parse} size="sm" flexShrink={0}>
|
||||||
|
{t('workflows.convertGraph')}
|
||||||
|
</Button>
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>{t('workflows.autoLayout')}</FormLabel>
|
||||||
|
<Checkbox isChecked={shouldAutoLayout} onChange={onChangeShouldAutoLayout} />
|
||||||
|
</FormControl>
|
||||||
|
<Spacer />
|
||||||
|
<Button onClick={loadWorkflow} size="sm" flexShrink={0}>
|
||||||
|
{t('workflows.loadWorkflow')}
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
<FormControl orientation="vertical" h="50%">
|
||||||
|
<FormLabel>{t('nodes.graph')}</FormLabel>
|
||||||
|
<Textarea
|
||||||
|
h="full"
|
||||||
|
value={graphRaw}
|
||||||
|
fontFamily="monospace"
|
||||||
|
whiteSpace="pre-wrap"
|
||||||
|
overflowWrap="normal"
|
||||||
|
onChange={onChangeGraphRaw}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormControl orientation="vertical" h="50%">
|
||||||
|
<FormLabel>{t('nodes.workflow')}</FormLabel>
|
||||||
|
<Textarea
|
||||||
|
h="full"
|
||||||
|
value={workflowRaw}
|
||||||
|
fontFamily="monospace"
|
||||||
|
whiteSpace="pre-wrap"
|
||||||
|
overflowWrap="normal"
|
||||||
|
onChange={onChangeWorkflowRaw}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</ModalBody>
|
||||||
|
</ModalContent>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,18 @@
|
|||||||
|
import { MenuItem } from '@invoke-ai/ui-library';
|
||||||
|
import { useLoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiFlaskBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
const LoadWorkflowFromGraphMenuItem = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { onOpen } = useLoadWorkflowFromGraphModal();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MenuItem as="button" icon={<PiFlaskBold />} onClick={onOpen}>
|
||||||
|
{t('workflows.loadFromGraph')}
|
||||||
|
</MenuItem>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(LoadWorkflowFromGraphMenuItem);
|
@ -6,8 +6,10 @@ import {
|
|||||||
MenuList,
|
MenuList,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
useGlobalMenuClose,
|
useGlobalMenuClose,
|
||||||
|
useShiftModifier,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import DownloadWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/DownloadWorkflowMenuItem';
|
import DownloadWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/DownloadWorkflowMenuItem';
|
||||||
|
import LoadWorkflowFromGraphMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem';
|
||||||
import { NewWorkflowMenuItem } from 'features/workflowLibrary/components/WorkflowLibraryMenu/NewWorkflowMenuItem';
|
import { NewWorkflowMenuItem } from 'features/workflowLibrary/components/WorkflowLibraryMenu/NewWorkflowMenuItem';
|
||||||
import SaveWorkflowAsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowAsMenuItem';
|
import SaveWorkflowAsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowAsMenuItem';
|
||||||
import SaveWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem';
|
import SaveWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem';
|
||||||
@ -20,6 +22,7 @@ import { PiDotsThreeOutlineFill } from 'react-icons/pi';
|
|||||||
const WorkflowLibraryMenu = () => {
|
const WorkflowLibraryMenu = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const shift = useShiftModifier();
|
||||||
useGlobalMenuClose(onClose);
|
useGlobalMenuClose(onClose);
|
||||||
return (
|
return (
|
||||||
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
|
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
|
||||||
@ -38,6 +41,8 @@ const WorkflowLibraryMenu = () => {
|
|||||||
<DownloadWorkflowMenuItem />
|
<DownloadWorkflowMenuItem />
|
||||||
<MenuDivider />
|
<MenuDivider />
|
||||||
<SettingsMenuItem />
|
<SettingsMenuItem />
|
||||||
|
{shift && <MenuDivider />}
|
||||||
|
{shift && <LoadWorkflowFromGraphMenuItem />}
|
||||||
</MenuList>
|
</MenuList>
|
||||||
</Menu>
|
</Menu>
|
||||||
);
|
);
|
||||||
|
@ -134,7 +134,6 @@ export type CollectInvocation = S['CollectInvocation'];
|
|||||||
export type ImageResizeInvocation = S['ImageResizeInvocation'];
|
export type ImageResizeInvocation = S['ImageResizeInvocation'];
|
||||||
export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation'];
|
export type InfillPatchMatchInvocation = S['InfillPatchMatchInvocation'];
|
||||||
export type InfillTileInvocation = S['InfillTileInvocation'];
|
export type InfillTileInvocation = S['InfillTileInvocation'];
|
||||||
export type CreateDenoiseMaskInvocation = S['CreateDenoiseMaskInvocation'];
|
|
||||||
export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
|
export type CreateGradientMaskInvocation = S['CreateGradientMaskInvocation'];
|
||||||
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
|
export type CanvasPasteBackInvocation = S['CanvasPasteBackInvocation'];
|
||||||
export type NoiseInvocation = S['NoiseInvocation'];
|
export type NoiseInvocation = S['NoiseInvocation'];
|
||||||
|