Merge branch 'main' into feat/nodes-phase-5

This commit is contained in:
blessedcoolant 2023-08-30 02:33:49 +12:00
commit 258b0814a8
10 changed files with 322 additions and 40 deletions

View File

@ -116,16 +116,15 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, truncate_long_prompts=False,
) )
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization: if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer) log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
@ -231,7 +230,7 @@ class SDXLPromptInvocationBase:
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO: truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled, requires_pooled=get_pooled,
) )
@ -240,8 +239,7 @@ class SDXLPromptInvocationBase:
if context.services.configuration.log_tokenization: if context.services.configuration.log_tokenization:
# TODO: better logging for and syntax # TODO: better logging for and syntax
for prompt_obj in conjunction.prompts: log_tokenization_for_conjunction(conjunction, tokenizer)
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice # TODO: ask for optimizations? to not run text_encoder twice
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import ( from diffusers.models.embeddings import (
TextImageProjection, TextImageProjection,
TextImageTimeEmbedding, TextImageTimeEmbedding,
@ -14,16 +14,9 @@ from diffusers.models.embeddings import (
Timesteps, Timesteps,
) )
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import ( from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -45,7 +38,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
Whether to flip the sin to cos in the time embedding. Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0): freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding. The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", \
"CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
@ -147,7 +141,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# If `num_attention_heads` is not defined (which is the case for most models) # If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced # The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # when this library was created...
# The incorrect naming was only discovered much ...
# later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here. # which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim num_attention_heads = num_attention_heads or attention_head_dim
@ -155,17 +151,20 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# Check inputs # Check inputs
if len(block_out_channels) != len(down_block_types): if len(block_out_channels) != len(down_block_types):
raise ValueError( raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." f"Must provide the same number of `block_out_channels` as `down_block_types`. \
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
) )
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError( raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." f"Must provide the same number of `only_cross_attention` as `down_block_types`. \
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
) )
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError( raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." f"Must provide the same number of `num_attention_heads` as `down_block_types`. \
`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
) )
if isinstance(transformer_layers_per_block, int): if isinstance(transformer_layers_per_block, int):
@ -202,7 +201,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj": elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # they are set to `cross_attention_dim` here as this is exactly the required dimension ...
# for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection( self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim, text_embed_dim=encoder_hid_dim,
@ -250,8 +250,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
) )
elif addition_embed_type == "text_image": elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension...
# for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding( self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
@ -673,12 +675,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
elif self.config.addition_embed_type == "text_time": elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs: if "text_embeds" not in added_cond_kwargs:
raise ValueError( raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
) )
text_embeds = added_cond_kwargs.get("text_embeds") text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs: if "time_ids" not in added_cond_kwargs:
raise ValueError( raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
) )
time_ids = added_cond_kwargs.get("time_ids") time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = self.add_time_proj(time_ids.flatten())
@ -776,3 +780,49 @@ def new_LoRACompatibleConv_forward(self, x):
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
try:
import xformers
xformers_available = True
except Exception:
xformers_available = False
if xformers_available:
# TODO: remove when fixed in diffusers
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
def new_memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias=None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op=None,
):
# diffusers not align shape to 8, which is required by xformers
if attn_bias is not None and type(attn_bias) is torch.Tensor:
orig_size = attn_bias.shape[-1]
new_size = ((orig_size + 7) // 8) * 8
aligned_attn_bias = torch.zeros(
(attn_bias.shape[0], attn_bias.shape[1], new_size),
device=attn_bias.device,
dtype=attn_bias.dtype,
)
aligned_attn_bias[:, :, :orig_size] = attn_bias
attn_bias = aligned_attn_bias[:, :, :orig_size]
return _xformers_memory_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
p=p,
scale=scale,
op=op,
)
xformers.ops.memory_efficient_attention = new_memory_efficient_attention

View File

@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery'; import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery'; import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
@ -41,6 +43,8 @@ import {
addImageUploadedFulfilledListener, addImageUploadedFulfilledListener,
addImageUploadedRejectedListener, addImageUploadedRejectedListener,
} from './listeners/imageUploaded'; } from './listeners/imageUploaded';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded'; import { addModelsLoadedListener } from './listeners/modelsLoaded';
@ -80,8 +84,6 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -137,6 +139,8 @@ addSessionReadyToInvokeListener();
// Canvas actions // Canvas actions
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener(); addCanvasMaskSavedToGalleryListener();
addCanvasImageToControlNetListener();
addCanvasMaskToControlNetListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();

View File

@ -0,0 +1,58 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
log.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};

View File

@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: 'Problem Importing Mask',
description: 'Unable to export mask',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};

View File

@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>( export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved' 'canvas/stagingAreaImageSaved'
); );
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');

View File

@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useToggle } from 'react-use'; import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview'; import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent'; import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
const { controlNetId } = controlNet; const { controlNetId } = controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ controlNet }) => { ({ controlNet }) => {
@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
> >
<ParamControlNetModel controlNet={controlNet} /> <ParamControlNetModel controlNet={controlNet} />
</Box> </Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
)}
<IAIIconButton <IAIIconButton
size="sm" size="sm"
tooltip="Duplicate" tooltip="Duplicate"
@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
/> />
)} )}
</Flex> </Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}> <Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex <Flex

View File

@ -10,8 +10,12 @@ import {
TypesafeDroppableData, TypesafeDroppableData,
} from 'features/dnd/types'; } from 'features/dnd/types';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa'; import { FaSave, FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon'; import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import { import {
@ -26,11 +30,13 @@ type Props = {
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ controlNet }) => { ({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet; const { pendingControlImages } = controlNet;
const { autoAddBoardId } = gallery;
return { return {
pendingControlImages, pendingControlImages,
autoAddBoardId,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -47,7 +53,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector); const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -59,9 +65,26 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
processedControlImageName ?? skipToken processedControlImageName ?? skipToken
); );
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleSaveControlImage = useCallback(() => {
if (!processedControlImage) {
return;
}
changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
});
addToBoard({ imageDTO: processedControlImage, board_id: autoAddBoardId });
}, [processedControlImage, autoAddBoardId, changeIsIntermediate, addToBoard]);
const handleMouseEnter = useCallback(() => { const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true); setIsMouseOverImage(true);
}, []); }, []);
@ -122,11 +145,19 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
isDropDisabled={shouldShowProcessedImage || !isEnabled} isDropDisabled={shouldShowProcessedImage || !isEnabled}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
> >
<>
<IAIDndImageIcon <IAIDndImageIcon
onClick={handleResetControlImage} onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined} icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image" tooltip="Reset Control Image"
/> />
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <FaSave size={16} /> : undefined}
tooltip="Save Control Image"
styleOverrides={{ marginTop: 6 }}
/>
</>
</IAIDndImage> </IAIDndImage>
<Box <Box

View File

@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';
type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
};
const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const dispatch = useAppDispatch();
const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);
return (
<Flex
sx={{
gap: 2,
}}
>
<IAIIconButton
size="sm"
icon={<FaImage />}
tooltip="Import Image From Canvas"
aria-label="Import Image From Canvas"
onClick={handleImportImageFromCanvas}
/>
<IAIIconButton
size="sm"
icon={<FaMask />}
tooltip="Import Mask From Canvas"
aria-label="Import Mask From Canvas"
onClick={handleImportMaskFromCanvas}
/>
</Flex>
);
};
export default memo(ControlNetCanvasImageImports);

View File

@ -36,7 +36,7 @@ dependencies = [
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=2.0.0", "compel~=2.0.2",
"controlnet-aux>=0.0.6", "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",