mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/nodes-phase-5
This commit is contained in:
commit
258b0814a8
@ -116,16 +116,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=True,
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
if context.services.configuration.log_tokenization:
|
||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||||
@ -231,7 +230,7 @@ class SDXLPromptInvocationBase:
|
|||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=True, # TODO:
|
truncate_long_prompts=False, # TODO:
|
||||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
requires_pooled=get_pooled,
|
requires_pooled=get_pooled,
|
||||||
)
|
)
|
||||||
@ -240,8 +239,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
if context.services.configuration.log_tokenization:
|
||||||
# TODO: better logging for and syntax
|
# TODO: better logging for and syntax
|
||||||
for prompt_obj in conjunction.prompts:
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
|
|
||||||
|
|
||||||
# TODO: ask for optimizations? to not run text_encoder twice
|
# TODO: ask for optimizations? to not run text_encoder twice
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import diffusers
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
from diffusers.loaders import FromOriginalControlnetMixin
|
from diffusers.loaders import FromOriginalControlnetMixin
|
||||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||||
|
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||||
from diffusers.models.embeddings import (
|
from diffusers.models.embeddings import (
|
||||||
TextImageProjection,
|
TextImageProjection,
|
||||||
TextImageTimeEmbedding,
|
TextImageTimeEmbedding,
|
||||||
@ -14,16 +14,9 @@ from diffusers.models.embeddings import (
|
|||||||
Timesteps,
|
Timesteps,
|
||||||
)
|
)
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from diffusers.models.unet_2d_blocks import (
|
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block
|
||||||
CrossAttnDownBlock2D,
|
|
||||||
DownBlock2D,
|
|
||||||
UNetMidBlock2DCrossAttn,
|
|
||||||
get_down_block,
|
|
||||||
)
|
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from torch import nn
|
||||||
import diffusers
|
|
||||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
@ -45,7 +38,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
Whether to flip the sin to cos in the time embedding.
|
Whether to flip the sin to cos in the time embedding.
|
||||||
freq_shift (`int`, defaults to 0):
|
freq_shift (`int`, defaults to 0):
|
||||||
The frequency shift to apply to the time embedding.
|
The frequency shift to apply to the time embedding.
|
||||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", \
|
||||||
|
"CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
The tuple of downsample blocks to use.
|
The tuple of downsample blocks to use.
|
||||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
@ -147,7 +141,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
# when this library was created...
|
||||||
|
# The incorrect naming was only discovered much ...
|
||||||
|
# later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
# which is why we correct for the naming here.
|
# which is why we correct for the naming here.
|
||||||
num_attention_heads = num_attention_heads or attention_head_dim
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
@ -155,17 +151,20 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
# Check inputs
|
# Check inputs
|
||||||
if len(block_out_channels) != len(down_block_types):
|
if len(block_out_channels) != len(down_block_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. \
|
||||||
|
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. \
|
||||||
|
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. \
|
||||||
|
`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(transformer_layers_per_block, int):
|
if isinstance(transformer_layers_per_block, int):
|
||||||
@ -202,7 +201,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||||
elif encoder_hid_dim_type == "text_image_proj":
|
elif encoder_hid_dim_type == "text_image_proj":
|
||||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension ...
|
||||||
|
# for the currently only use
|
||||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||||
self.encoder_hid_proj = TextImageProjection(
|
self.encoder_hid_proj = TextImageProjection(
|
||||||
text_embed_dim=encoder_hid_dim,
|
text_embed_dim=encoder_hid_dim,
|
||||||
@ -250,8 +250,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||||
)
|
)
|
||||||
elif addition_embed_type == "text_image":
|
elif addition_embed_type == "text_image":
|
||||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
|
||||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
# To not clutter the __init__ too much
|
||||||
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension...
|
||||||
|
# for the currently only use
|
||||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||||
self.add_embedding = TextImageTimeEmbedding(
|
self.add_embedding = TextImageTimeEmbedding(
|
||||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||||
@ -673,12 +675,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
|||||||
elif self.config.addition_embed_type == "text_time":
|
elif self.config.addition_embed_type == "text_time":
|
||||||
if "text_embeds" not in added_cond_kwargs:
|
if "text_embeds" not in added_cond_kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
|
||||||
|
requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||||
)
|
)
|
||||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||||
if "time_ids" not in added_cond_kwargs:
|
if "time_ids" not in added_cond_kwargs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
|
||||||
|
requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||||
)
|
)
|
||||||
time_ids = added_cond_kwargs.get("time_ids")
|
time_ids = added_cond_kwargs.get("time_ids")
|
||||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||||
@ -776,3 +780,49 @@ def new_LoRACompatibleConv_forward(self, x):
|
|||||||
|
|
||||||
|
|
||||||
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
|
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xformers
|
||||||
|
|
||||||
|
xformers_available = True
|
||||||
|
except Exception:
|
||||||
|
xformers_available = False
|
||||||
|
|
||||||
|
|
||||||
|
if xformers_available:
|
||||||
|
# TODO: remove when fixed in diffusers
|
||||||
|
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
|
||||||
|
|
||||||
|
def new_memory_efficient_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_bias=None,
|
||||||
|
p: float = 0.0,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
*,
|
||||||
|
op=None,
|
||||||
|
):
|
||||||
|
# diffusers not align shape to 8, which is required by xformers
|
||||||
|
if attn_bias is not None and type(attn_bias) is torch.Tensor:
|
||||||
|
orig_size = attn_bias.shape[-1]
|
||||||
|
new_size = ((orig_size + 7) // 8) * 8
|
||||||
|
aligned_attn_bias = torch.zeros(
|
||||||
|
(attn_bias.shape[0], attn_bias.shape[1], new_size),
|
||||||
|
device=attn_bias.device,
|
||||||
|
dtype=attn_bias.dtype,
|
||||||
|
)
|
||||||
|
aligned_attn_bias[:, :, :orig_size] = attn_bias
|
||||||
|
attn_bias = aligned_attn_bias[:, :, :orig_size]
|
||||||
|
|
||||||
|
return _xformers_memory_efficient_attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
attn_bias=attn_bias,
|
||||||
|
p=p,
|
||||||
|
scale=scale,
|
||||||
|
op=op,
|
||||||
|
)
|
||||||
|
|
||||||
|
xformers.ops.memory_efficient_attention = new_memory_efficient_attention
|
||||||
|
@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
|
|||||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||||
|
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
|
||||||
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
|
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
|
||||||
|
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
|
||||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||||
@ -41,6 +43,8 @@ import {
|
|||||||
addImageUploadedFulfilledListener,
|
addImageUploadedFulfilledListener,
|
||||||
addImageUploadedRejectedListener,
|
addImageUploadedRejectedListener,
|
||||||
} from './listeners/imageUploaded';
|
} from './listeners/imageUploaded';
|
||||||
|
import { addImagesStarredListener } from './listeners/imagesStarred';
|
||||||
|
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
||||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||||
@ -80,8 +84,6 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
import { addImagesStarredListener } from './listeners/imagesStarred';
|
|
||||||
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
|
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -137,6 +139,8 @@ addSessionReadyToInvokeListener();
|
|||||||
// Canvas actions
|
// Canvas actions
|
||||||
addCanvasSavedToGalleryListener();
|
addCanvasSavedToGalleryListener();
|
||||||
addCanvasMaskSavedToGalleryListener();
|
addCanvasMaskSavedToGalleryListener();
|
||||||
|
addCanvasImageToControlNetListener();
|
||||||
|
addCanvasMaskToControlNetListener();
|
||||||
addCanvasDownloadedAsImageListener();
|
addCanvasDownloadedAsImageListener();
|
||||||
addCanvasCopiedToClipboardListener();
|
addCanvasCopiedToClipboardListener();
|
||||||
addCanvasMergedListener();
|
addCanvasMergedListener();
|
||||||
|
@ -0,0 +1,58 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { canvasImageToControlNet } from 'features/canvas/store/actions';
|
||||||
|
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||||
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addCanvasImageToControlNetListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: canvasImageToControlNet,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('canvas');
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const blob = await getBaseLayerBlob(state);
|
||||||
|
|
||||||
|
if (!blob) {
|
||||||
|
log.error('Problem getting base layer blob');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: 'Problem Saving Canvas',
|
||||||
|
description: 'Unable to export base layer',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { autoAddBoardId } = state.gallery;
|
||||||
|
|
||||||
|
const imageDTO = await dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({
|
||||||
|
file: new File([blob], 'savedCanvas.png', {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
|
image_category: 'mask',
|
||||||
|
is_intermediate: false,
|
||||||
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
|
crop_visible: true,
|
||||||
|
postUploadAction: {
|
||||||
|
type: 'TOAST',
|
||||||
|
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
const { image_name } = imageDTO;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetImageChanged({
|
||||||
|
controlNetId: action.payload.controlNet.controlNetId,
|
||||||
|
controlImage: image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,70 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
|
||||||
|
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||||
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addCanvasMaskToControlNetListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: canvasMaskToControlNet,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const log = logger('canvas');
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const canvasBlobsAndImageData = await getCanvasData(
|
||||||
|
state.canvas.layerState,
|
||||||
|
state.canvas.boundingBoxCoordinates,
|
||||||
|
state.canvas.boundingBoxDimensions,
|
||||||
|
state.canvas.isMaskEnabled,
|
||||||
|
state.canvas.shouldPreserveMaskedArea
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!canvasBlobsAndImageData) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { maskBlob } = canvasBlobsAndImageData;
|
||||||
|
|
||||||
|
if (!maskBlob) {
|
||||||
|
log.error('Problem getting mask layer blob');
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: 'Problem Importing Mask',
|
||||||
|
description: 'Unable to export mask',
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { autoAddBoardId } = state.gallery;
|
||||||
|
|
||||||
|
const imageDTO = await dispatch(
|
||||||
|
imagesApi.endpoints.uploadImage.initiate({
|
||||||
|
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
|
image_category: 'mask',
|
||||||
|
is_intermediate: false,
|
||||||
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
|
crop_visible: true,
|
||||||
|
postUploadAction: {
|
||||||
|
type: 'TOAST',
|
||||||
|
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
const { image_name } = imageDTO;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetImageChanged({
|
||||||
|
controlNetId: action.payload.controlNet.controlNetId,
|
||||||
|
controlImage: image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,4 +1,5 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
||||||
@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
|
|||||||
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
|
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
|
||||||
'canvas/stagingAreaImageSaved'
|
'canvas/stagingAreaImageSaved'
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const canvasMaskToControlNet = createAction<{
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
}>('canvas/canvasMaskToControlNet');
|
||||||
|
|
||||||
|
export const canvasImageToControlNet = createAction<{
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
}>('canvas/canvasImageToControlNet');
|
||||||
|
@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { useToggle } from 'react-use';
|
import { useToggle } from 'react-use';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import ControlNetImagePreview from './ControlNetImagePreview';
|
import ControlNetImagePreview from './ControlNetImagePreview';
|
||||||
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
|
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
|
||||||
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||||
|
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
||||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||||
@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
const { controlNetId } = controlNet;
|
const { controlNetId } = controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlNet }) => {
|
({ controlNet }) => {
|
||||||
@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
>
|
>
|
||||||
<ParamControlNetModel controlNet={controlNet} />
|
<ParamControlNetModel controlNet={controlNet} />
|
||||||
</Box>
|
</Box>
|
||||||
|
{activeTabName === 'unifiedCanvas' && (
|
||||||
|
<ControlNetCanvasImageImports controlNet={controlNet} />
|
||||||
|
)}
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
tooltip="Duplicate"
|
tooltip="Duplicate"
|
||||||
@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
|
||||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||||
<Flex
|
<Flex
|
||||||
|
@ -10,8 +10,12 @@ import {
|
|||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaSave, FaUndo } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import {
|
||||||
|
useAddImageToBoardMutation,
|
||||||
|
useChangeImageIsIntermediateMutation,
|
||||||
|
useGetImageDTOQuery,
|
||||||
|
} from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
|
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
|
||||||
import {
|
import {
|
||||||
@ -26,11 +30,13 @@ type Props = {
|
|||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlNet }) => {
|
({ controlNet, gallery }) => {
|
||||||
const { pendingControlImages } = controlNet;
|
const { pendingControlImages } = controlNet;
|
||||||
|
const { autoAddBoardId } = gallery;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
pendingControlImages,
|
pendingControlImages,
|
||||||
|
autoAddBoardId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -47,7 +53,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { pendingControlImages } = useAppSelector(selector);
|
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
@ -59,9 +65,26 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
|||||||
processedControlImageName ?? skipToken
|
processedControlImageName ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
|
||||||
|
const [addToBoard] = useAddImageToBoardMutation();
|
||||||
|
|
||||||
const handleResetControlImage = useCallback(() => {
|
const handleResetControlImage = useCallback(() => {
|
||||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
|
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
|
||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
const handleSaveControlImage = useCallback(() => {
|
||||||
|
if (!processedControlImage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
changeIsIntermediate({
|
||||||
|
imageDTO: processedControlImage,
|
||||||
|
is_intermediate: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
addToBoard({ imageDTO: processedControlImage, board_id: autoAddBoardId });
|
||||||
|
}, [processedControlImage, autoAddBoardId, changeIsIntermediate, addToBoard]);
|
||||||
|
|
||||||
const handleMouseEnter = useCallback(() => {
|
const handleMouseEnter = useCallback(() => {
|
||||||
setIsMouseOverImage(true);
|
setIsMouseOverImage(true);
|
||||||
}, []);
|
}, []);
|
||||||
@ -122,11 +145,19 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
|||||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||||
postUploadAction={postUploadAction}
|
postUploadAction={postUploadAction}
|
||||||
>
|
>
|
||||||
|
<>
|
||||||
<IAIDndImageIcon
|
<IAIDndImageIcon
|
||||||
onClick={handleResetControlImage}
|
onClick={handleResetControlImage}
|
||||||
icon={controlImage ? <FaUndo /> : undefined}
|
icon={controlImage ? <FaUndo /> : undefined}
|
||||||
tooltip="Reset Control Image"
|
tooltip="Reset Control Image"
|
||||||
/>
|
/>
|
||||||
|
<IAIDndImageIcon
|
||||||
|
onClick={handleSaveControlImage}
|
||||||
|
icon={controlImage ? <FaSave size={16} /> : undefined}
|
||||||
|
tooltip="Save Control Image"
|
||||||
|
styleOverrides={{ marginTop: 6 }}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
</IAIDndImage>
|
</IAIDndImage>
|
||||||
|
|
||||||
<Box
|
<Box
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import {
|
||||||
|
canvasImageToControlNet,
|
||||||
|
canvasMaskToControlNet,
|
||||||
|
} from 'features/canvas/store/actions';
|
||||||
|
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { FaImage, FaMask } from 'react-icons/fa';
|
||||||
|
|
||||||
|
type ControlNetCanvasImageImportsProps = {
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ControlNetCanvasImageImports = (
|
||||||
|
props: ControlNetCanvasImageImportsProps
|
||||||
|
) => {
|
||||||
|
const { controlNet } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleImportImageFromCanvas = useCallback(() => {
|
||||||
|
dispatch(canvasImageToControlNet({ controlNet }));
|
||||||
|
}, [controlNet, dispatch]);
|
||||||
|
|
||||||
|
const handleImportMaskFromCanvas = useCallback(() => {
|
||||||
|
dispatch(canvasMaskToControlNet({ controlNet }));
|
||||||
|
}, [controlNet, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
gap: 2,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
icon={<FaImage />}
|
||||||
|
tooltip="Import Image From Canvas"
|
||||||
|
aria-label="Import Image From Canvas"
|
||||||
|
onClick={handleImportImageFromCanvas}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
icon={<FaMask />}
|
||||||
|
tooltip="Import Mask From Canvas"
|
||||||
|
aria-label="Import Mask From Canvas"
|
||||||
|
onClick={handleImportMaskFromCanvas}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetCanvasImageImports);
|
@ -36,7 +36,7 @@ dependencies = [
|
|||||||
"albumentations",
|
"albumentations",
|
||||||
"click",
|
"click",
|
||||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel~=2.0.0",
|
"compel~=2.0.2",
|
||||||
"controlnet-aux>=0.0.6",
|
"controlnet-aux>=0.0.6",
|
||||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||||
"datasets",
|
"datasets",
|
||||||
|
Loading…
Reference in New Issue
Block a user