Merge branch 'invoke-ai:main' into main

This commit is contained in:
Millun Atluri 2023-09-28 09:41:29 +10:00 committed by GitHub
commit ef0754cdec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 479 additions and 169 deletions

View File

@ -9,6 +9,8 @@ from diffusers.models import UNet2DConditionModel
from PIL import Image from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0 from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler from .resampler import Resampler
@ -87,6 +89,20 @@ class IPAdapter:
if self._attn_processors is not None: if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype) torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def calc_size(self):
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
def _init_image_proj_model(self, state_dict): def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype) return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_management.models.base import (
ModelConfigBase, ModelConfigBase,
ModelType, ModelType,
SubModelType, SubModelType,
calc_model_size_by_fs,
classproperty, classproperty,
) )
@ -30,7 +31,7 @@ class IPAdapterModel(ModelBase):
assert model_type == ModelType.IPAdapter assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type) super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path) self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod @classmethod
def detect_format(cls, path: str) -> str: def detect_format(cls, path: str) -> str:
@ -63,10 +64,13 @@ class IPAdapterModel(ModelBase):
if child_type is not None: if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.") raise ValueError("There are no child models in an IP-Adapter model.")
return build_ip_adapter( model = build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
) )
self.model_size = model.calc_size()
return model
@classmethod @classmethod
def convert_if_required( def convert_if_required(
cls, cls,

View File

@ -80,7 +80,7 @@
"lightMode": "Light Mode", "lightMode": "Light Mode",
"linear": "Linear", "linear": "Linear",
"load": "Load", "load": "Load",
"loading": "Loading $t({{noun}})...", "loading": "Loading",
"loadingInvokeAI": "Loading Invoke AI", "loadingInvokeAI": "Loading Invoke AI",
"learnMore": "Learn More", "learnMore": "Learn More",
"modelManager": "Model Manager", "modelManager": "Model Manager",
@ -1444,6 +1444,8 @@
"showCanvasDebugInfo": "Show Additional Canvas Info", "showCanvasDebugInfo": "Show Additional Canvas Info",
"showGrid": "Show Grid", "showGrid": "Show Grid",
"showHide": "Show/Hide", "showHide": "Show/Hide",
"showResultsOn": "Show Results (On)",
"showResultsOff": "Show Results (Off)",
"showIntermediates": "Show Intermediates", "showIntermediates": "Show Intermediates",
"snapToGrid": "Snap to Grid", "snapToGrid": "Snap to Grid",
"undo": "Undo" "undo": "Undo"

View File

@ -1,6 +1,8 @@
import { Flex, Grid } from '@chakra-ui/react'; import { Flex, Grid } from '@chakra-ui/react';
import { useStore } from '@nanostores/react';
import { useLogger } from 'app/logging/useLogger'; import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted'; import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { $headerComponent } from 'app/store/nanostores/headerComponent';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
@ -14,12 +16,10 @@ import i18n from 'i18n';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react'; import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary'; import { ErrorBoundary } from 'react-error-boundary';
import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster'; import Toaster from './Toaster';
import { useStore } from '@nanostores/react';
import { $headerComponent } from 'app/store/nanostores/headerComponent';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -36,8 +36,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
const logger = useLogger('system'); const logger = useLogger('system');
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { handleSendToCanvas, handleSendToImg2Img, handleUseAllMetadata } =
usePreselectedImage(selectedImage?.imageName);
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
localStorage.clear(); localStorage.clear();
location.reload(); location.reload();
@ -59,24 +58,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
dispatch(appStarted()); dispatch(appStarted());
}, [dispatch]); }, [dispatch]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'sendToCanvas') {
handleSendToCanvas();
}
}, [selectedImage, handleSendToCanvas]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'sendToImg2Img') {
handleSendToImg2Img();
}
}, [selectedImage, handleSendToImg2Img]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'useAllParameters') {
handleUseAllMetadata();
}
}, [selectedImage, handleUseAllMetadata]);
const headerComponent = useStore($headerComponent); const headerComponent = useStore($headerComponent);
return ( return (
@ -112,6 +93,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
<ChangeBoardModal /> <ChangeBoardModal />
<Toaster /> <Toaster />
<GlobalHotkeys /> <GlobalHotkeys />
<PreselectedImage selectedImage={selectedImage} />
</ErrorBoundary> </ErrorBoundary>
); );
}; };

View File

@ -0,0 +1,16 @@
import { usePreselectedImage } from 'features/parameters/hooks/usePreselectedImage';
import { memo } from 'react';
type Props = {
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
};
const PreselectedImage = (props: Props) => {
usePreselectedImage(props.selectedImage);
return null;
};
export default memo(PreselectedImage);

View File

@ -25,7 +25,7 @@ export const addBoardIdSelectedListener = () => {
const state = getState(); const state = getState();
const board_id = boardIdSelected.match(action) const board_id = boardIdSelected.match(action)
? action.payload ? action.payload.boardId
: state.gallery.selectedBoardId; : state.gallery.selectedBoardId;
const galleryView = galleryViewChanged.match(action) const galleryView = galleryViewChanged.match(action)
@ -55,7 +55,12 @@ export const addBoardIdSelectedListener = () => {
if (boardImagesData) { if (boardImagesData) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0]; const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage ?? null)); const selectedImage = imagesSelectors.selectById(
boardImagesData,
action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage || firstImage || null));
} else { } else {
// board has no images - deselect // board has no images - deselect
dispatch(imageSelected(null)); dispatch(imageSelected(null));

View File

@ -81,9 +81,32 @@ export const addInvocationCompleteEventListener = () => {
// If auto-switch is enabled, select the new image // If auto-switch is enabled, select the new image
if (shouldAutoSwitch) { if (shouldAutoSwitch) {
// if auto-add is enabled, switch the board as the image comes in // if auto-add is enabled, switch the gallery view and board if needed as the image comes in
if (gallery.galleryView !== 'images') {
dispatch(galleryViewChanged('images')); dispatch(galleryViewChanged('images'));
dispatch(boardIdSelected(imageDTO.board_id ?? 'none')); }
if (
imageDTO.board_id &&
imageDTO.board_id !== gallery.selectedBoardId
) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
})
);
}
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
})
);
}
dispatch(imageSelected(imageDTO)); dispatch(imageSelected(imageDTO));
} }
} }

View File

@ -139,6 +139,11 @@ const IAICanvas = () => {
const { handleDragStart, handleDragMove, handleDragEnd } = const { handleDragStart, handleDragMove, handleDragEnd } =
useCanvasDragMove(); useCanvasDragMove();
const handleContextMenu = useCallback(
(e: KonvaEventObject<MouseEvent>) => e.evt.preventDefault(),
[]
);
useEffect(() => { useEffect(() => {
if (!containerRef.current) { if (!containerRef.current) {
return; return;
@ -205,9 +210,7 @@ const IAICanvas = () => {
onDragStart={handleDragStart} onDragStart={handleDragStart}
onDragMove={handleDragMove} onDragMove={handleDragMove}
onDragEnd={handleDragEnd} onDragEnd={handleDragEnd}
onContextMenu={(e: KonvaEventObject<MouseEvent>) => onContextMenu={handleContextMenu}
e.evt.preventDefault()
}
onWheel={handleWheel} onWheel={handleWheel}
draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox} draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox}
> >
@ -223,7 +226,11 @@ const IAICanvas = () => {
> >
<IAICanvasObjectRenderer /> <IAICanvasObjectRenderer />
</Layer> </Layer>
<Layer id="mask" visible={isMaskEnabled} listening={false}> <Layer
id="mask"
visible={isMaskEnabled && !isStaging}
listening={false}
>
<IAICanvasMaskLines visible={true} listening={false} /> <IAICanvasMaskLines visible={true} listening={false} />
<IAICanvasMaskCompositer listening={false} /> <IAICanvasMaskCompositer listening={false} />
</Layer> </Layer>

View File

@ -11,7 +11,7 @@ const IAICanvasImageErrorFallback = ({
canvasImage, canvasImage,
}: IAICanvasImageErrorFallbackProps) => { }: IAICanvasImageErrorFallbackProps) => {
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] = const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
useToken('colors', ['gray.400', 'gray.500', 'base.700', 'base.900']); useToken('colors', ['base.400', 'base.500', 'base.700', 'base.900']);
const errorColor = useColorModeValue(errorColorLight, errorColorDark); const errorColor = useColorModeValue(errorColorLight, errorColorDark);
const fontColor = useColorModeValue(fontColorLight, fontColorDark); const fontColor = useColorModeValue(fontColorLight, fontColorDark);
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -3,10 +3,9 @@ import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group'; import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo } from 'react';
import { Group, Rect } from 'react-konva'; import { Group, Rect } from 'react-konva';
import IAICanvasImage from './IAICanvasImage'; import IAICanvasImage from './IAICanvasImage';
import { memo } from 'react';
const selector = createSelector( const selector = createSelector(
[canvasSelector], [canvasSelector],
@ -15,11 +14,11 @@ const selector = createSelector(
layerState, layerState,
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
boundingBoxCoordinates: { x, y }, boundingBoxCoordinates: stageBoundingBoxCoordinates,
boundingBoxDimensions: { width, height }, boundingBoxDimensions: stageBoundingBoxDimensions,
} = canvas; } = canvas;
const { selectedImageIndex, images } = layerState.stagingArea; const { selectedImageIndex, images, boundingBox } = layerState.stagingArea;
return { return {
currentStagingAreaImage: currentStagingAreaImage:
@ -30,10 +29,10 @@ const selector = createSelector(
isOnLastImage: selectedImageIndex === images.length - 1, isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
x, x: boundingBox?.x ?? stageBoundingBoxCoordinates.x,
y, y: boundingBox?.y ?? stageBoundingBoxCoordinates.y,
width, width: boundingBox?.width ?? stageBoundingBoxDimensions.width,
height, height: boundingBox?.height ?? stageBoundingBoxDimensions.height,
}; };
}, },
{ {

View File

@ -14,6 +14,7 @@ import {
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -23,8 +24,8 @@ import {
FaCheck, FaCheck,
FaEye, FaEye,
FaEyeSlash, FaEyeSlash,
FaPlus,
FaSave, FaSave,
FaTimes,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { stagingAreaImageSaved } from '../store/actions'; import { stagingAreaImageSaved } from '../store/actions';
@ -41,10 +42,10 @@ const selector = createSelector(
} = canvas; } = canvas;
return { return {
currentIndex: selectedImageIndex,
total: images.length,
currentStagingAreaImage: currentStagingAreaImage:
images.length > 0 ? images[selectedImageIndex] : undefined, images.length > 0 ? images[selectedImageIndex] : undefined,
isOnFirstImage: selectedImageIndex === 0,
isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
}; };
@ -55,10 +56,10 @@ const selector = createSelector(
const IAICanvasStagingAreaToolbar = () => { const IAICanvasStagingAreaToolbar = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { const {
isOnFirstImage,
isOnLastImage,
currentStagingAreaImage, currentStagingAreaImage,
shouldShowStagingImage, shouldShowStagingImage,
currentIndex,
total,
} = useAppSelector(selector); } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -71,39 +72,6 @@ const IAICanvasStagingAreaToolbar = () => {
dispatch(setShouldShowStagingOutline(false)); dispatch(setShouldShowStagingOutline(false));
}, [dispatch]); }, [dispatch]);
useHotkeys(
['left'],
() => {
handlePrevImage();
},
{
enabled: () => true,
preventDefault: true,
}
);
useHotkeys(
['right'],
() => {
handleNextImage();
},
{
enabled: () => true,
preventDefault: true,
}
);
useHotkeys(
['enter'],
() => {
handleAccept();
},
{
enabled: () => true,
preventDefault: true,
}
);
const handlePrevImage = useCallback( const handlePrevImage = useCallback(
() => dispatch(prevStagingAreaImage()), () => dispatch(prevStagingAreaImage()),
[dispatch] [dispatch]
@ -119,64 +87,30 @@ const IAICanvasStagingAreaToolbar = () => {
[dispatch] [dispatch]
); );
useHotkeys(['left'], handlePrevImage, {
enabled: () => true,
preventDefault: true,
});
useHotkeys(['right'], handleNextImage, {
enabled: () => true,
preventDefault: true,
});
useHotkeys(['enter'], () => handleAccept, {
enabled: () => true,
preventDefault: true,
});
const { data: imageDTO } = useGetImageDTOQuery( const { data: imageDTO } = useGetImageDTOQuery(
currentStagingAreaImage?.imageName ?? skipToken currentStagingAreaImage?.imageName ?? skipToken
); );
if (!currentStagingAreaImage) { const handleToggleShouldShowStagingImage = useCallback(() => {
return null; dispatch(setShouldShowStagingImage(!shouldShowStagingImage));
} }, [dispatch, shouldShowStagingImage]);
return ( const handleSaveToGallery = useCallback(() => {
<Flex
pos="absolute"
bottom={4}
w="100%"
align="center"
justify="center"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
<IAIIconButton
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
icon={<FaArrowLeft />}
onClick={handlePrevImage}
colorScheme="accent"
isDisabled={isOnFirstImage}
/>
<IAIIconButton
tooltip={`${t('unifiedCanvas.next')} (Right)`}
aria-label={`${t('unifiedCanvas.next')} (Right)`}
icon={<FaArrowRight />}
onClick={handleNextImage}
colorScheme="accent"
isDisabled={isOnLastImage}
/>
<IAIIconButton
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
icon={<FaCheck />}
onClick={handleAccept}
colorScheme="accent"
/>
<IAIIconButton
tooltip={t('unifiedCanvas.showHide')}
aria-label={t('unifiedCanvas.showHide')}
data-alert={!shouldShowStagingImage}
icon={shouldShowStagingImage ? <FaEye /> : <FaEyeSlash />}
onClick={() =>
dispatch(setShouldShowStagingImage(!shouldShowStagingImage))
}
colorScheme="accent"
/>
<IAIIconButton
tooltip={t('unifiedCanvas.saveToGallery')}
aria-label={t('unifiedCanvas.saveToGallery')}
isDisabled={!imageDTO || !imageDTO.is_intermediate}
icon={<FaSave />}
onClick={() => {
if (!imageDTO) { if (!imageDTO) {
return; return;
} }
@ -186,14 +120,93 @@ const IAICanvasStagingAreaToolbar = () => {
imageDTO, imageDTO,
}) })
); );
}, [dispatch, imageDTO]);
const handleDiscardStagingArea = useCallback(() => {
dispatch(discardStagedImages());
}, [dispatch]);
if (!currentStagingAreaImage) {
return null;
}
return (
<Flex
pos="absolute"
bottom={4}
gap={2}
w="100%"
align="center"
justify="center"
onMouseEnter={handleMouseOver}
onMouseLeave={handleMouseOut}
>
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
<IAIIconButton
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
icon={<FaArrowLeft />}
onClick={handlePrevImage}
colorScheme="accent"
isDisabled={!shouldShowStagingImage}
/>
<IAIButton
colorScheme="accent"
pointerEvents="none"
isDisabled={!shouldShowStagingImage}
sx={{
background: 'base.600',
_dark: {
background: 'base.800',
},
}} }}
>{`${currentIndex + 1}/${total}`}</IAIButton>
<IAIIconButton
tooltip={`${t('unifiedCanvas.next')} (Right)`}
aria-label={`${t('unifiedCanvas.next')} (Right)`}
icon={<FaArrowRight />}
onClick={handleNextImage}
colorScheme="accent"
isDisabled={!shouldShowStagingImage}
/>
</ButtonGroup>
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
<IAIIconButton
tooltip={`${t('unifiedCanvas.accept')} (Enter)`}
aria-label={`${t('unifiedCanvas.accept')} (Enter)`}
icon={<FaCheck />}
onClick={handleAccept}
colorScheme="accent"
/>
<IAIIconButton
tooltip={
shouldShowStagingImage
? t('unifiedCanvas.showResultsOn')
: t('unifiedCanvas.showResultsOff')
}
aria-label={
shouldShowStagingImage
? t('unifiedCanvas.showResultsOn')
: t('unifiedCanvas.showResultsOff')
}
data-alert={!shouldShowStagingImage}
icon={shouldShowStagingImage ? <FaEye /> : <FaEyeSlash />}
onClick={handleToggleShouldShowStagingImage}
colorScheme="accent"
/>
<IAIIconButton
tooltip={t('unifiedCanvas.saveToGallery')}
aria-label={t('unifiedCanvas.saveToGallery')}
isDisabled={!imageDTO || !imageDTO.is_intermediate}
icon={<FaSave />}
onClick={handleSaveToGallery}
colorScheme="accent" colorScheme="accent"
/> />
<IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.discardAll')} tooltip={t('unifiedCanvas.discardAll')}
aria-label={t('unifiedCanvas.discardAll')} aria-label={t('unifiedCanvas.discardAll')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />} icon={<FaTimes />}
onClick={() => dispatch(discardStagedImages())} onClick={handleDiscardStagingArea}
colorScheme="error" colorScheme="error"
fontSize={20} fontSize={20}
/> />

View File

@ -213,45 +213,45 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
[scaledStep] [scaledStep]
); );
const handleStartedTransforming = () => { const handleStartedTransforming = useCallback(() => {
dispatch(setIsTransformingBoundingBox(true)); dispatch(setIsTransformingBoundingBox(true));
}; }, [dispatch]);
const handleEndedTransforming = () => { const handleEndedTransforming = useCallback(() => {
dispatch(setIsTransformingBoundingBox(false)); dispatch(setIsTransformingBoundingBox(false));
dispatch(setIsMovingBoundingBox(false)); dispatch(setIsMovingBoundingBox(false));
dispatch(setIsMouseOverBoundingBox(false)); dispatch(setIsMouseOverBoundingBox(false));
setIsMouseOverBoundingBoxOutline(false); setIsMouseOverBoundingBoxOutline(false);
}; }, [dispatch]);
const handleStartedMoving = () => { const handleStartedMoving = useCallback(() => {
dispatch(setIsMovingBoundingBox(true)); dispatch(setIsMovingBoundingBox(true));
}; }, [dispatch]);
const handleEndedModifying = () => { const handleEndedModifying = useCallback(() => {
dispatch(setIsTransformingBoundingBox(false)); dispatch(setIsTransformingBoundingBox(false));
dispatch(setIsMovingBoundingBox(false)); dispatch(setIsMovingBoundingBox(false));
dispatch(setIsMouseOverBoundingBox(false)); dispatch(setIsMouseOverBoundingBox(false));
setIsMouseOverBoundingBoxOutline(false); setIsMouseOverBoundingBoxOutline(false);
}; }, [dispatch]);
const handleMouseOver = () => { const handleMouseOver = useCallback(() => {
setIsMouseOverBoundingBoxOutline(true); setIsMouseOverBoundingBoxOutline(true);
}; }, []);
const handleMouseOut = () => { const handleMouseOut = useCallback(() => {
!isTransformingBoundingBox && !isTransformingBoundingBox &&
!isMovingBoundingBox && !isMovingBoundingBox &&
setIsMouseOverBoundingBoxOutline(false); setIsMouseOverBoundingBoxOutline(false);
}; }, [isMovingBoundingBox, isTransformingBoundingBox]);
const handleMouseEnterBoundingBox = () => { const handleMouseEnterBoundingBox = useCallback(() => {
dispatch(setIsMouseOverBoundingBox(true)); dispatch(setIsMouseOverBoundingBox(true));
}; }, [dispatch]);
const handleMouseLeaveBoundingBox = () => { const handleMouseLeaveBoundingBox = useCallback(() => {
dispatch(setIsMouseOverBoundingBox(false)); dispatch(setIsMouseOverBoundingBox(false));
}; }, [dispatch]);
return ( return (
<Group {...rest}> <Group {...rest}>

View File

@ -6,7 +6,7 @@ export const canvasSelector = (state: RootState): CanvasState => state.canvas;
export const isStagingSelector = createSelector( export const isStagingSelector = createSelector(
[stateSelector], [stateSelector],
({ canvas }) => canvas.layerState.stagingArea.images.length > 0 ({ canvas }) => canvas.batchIds.length > 0
); );
export const initialCanvasImageSelector = ( export const initialCanvasImageSelector = (

View File

@ -186,7 +186,7 @@ export const canvasSlice = createSlice({
state.pastLayerStates.push(cloneDeep(state.layerState)); state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = { state.layerState = {
...initialLayerState, ...cloneDeep(initialLayerState),
objects: [ objects: [
{ {
kind: 'image', kind: 'image',
@ -200,6 +200,7 @@ export const canvasSlice = createSlice({
], ],
}; };
state.futureLayerStates = []; state.futureLayerStates = [];
state.batchIds = [];
const newScale = calculateScale( const newScale = calculateScale(
stageDimensions.width, stageDimensions.width,
@ -349,11 +350,14 @@ export const canvasSlice = createSlice({
state.pastLayerStates.shift(); state.pastLayerStates.shift();
} }
state.layerState.stagingArea = { ...initialLayerState.stagingArea }; state.layerState.stagingArea = cloneDeep(
cloneDeep(initialLayerState)
).stagingArea;
state.futureLayerStates = []; state.futureLayerStates = [];
state.shouldShowStagingOutline = true; state.shouldShowStagingOutline = true;
state.shouldShowStagingOutline = true; state.shouldShowStagingImage = true;
state.batchIds = [];
}, },
addFillRect: (state) => { addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } =
@ -490,8 +494,9 @@ export const canvasSlice = createSlice({
resetCanvas: (state) => { resetCanvas: (state) => {
state.pastLayerStates.push(cloneDeep(state.layerState)); state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = initialLayerState; state.layerState = cloneDeep(initialLayerState);
state.futureLayerStates = []; state.futureLayerStates = [];
state.batchIds = [];
}, },
canvasResized: ( canvasResized: (
state, state,
@ -616,25 +621,22 @@ export const canvasSlice = createSlice({
return; return;
} }
const currentIndex = state.layerState.stagingArea.selectedImageIndex; const nextIndex = state.layerState.stagingArea.selectedImageIndex + 1;
const length = state.layerState.stagingArea.images.length; const lastIndex = state.layerState.stagingArea.images.length - 1;
state.layerState.stagingArea.selectedImageIndex = Math.min( state.layerState.stagingArea.selectedImageIndex =
currentIndex + 1, nextIndex > lastIndex ? 0 : nextIndex;
length - 1
);
}, },
prevStagingAreaImage: (state) => { prevStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) { if (!state.layerState.stagingArea.images.length) {
return; return;
} }
const currentIndex = state.layerState.stagingArea.selectedImageIndex; const prevIndex = state.layerState.stagingArea.selectedImageIndex - 1;
const lastIndex = state.layerState.stagingArea.images.length - 1;
state.layerState.stagingArea.selectedImageIndex = Math.max( state.layerState.stagingArea.selectedImageIndex =
currentIndex - 1, prevIndex < 0 ? lastIndex : prevIndex;
0
);
}, },
commitStagingAreaImage: (state) => { commitStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) { if (!state.layerState.stagingArea.images.length) {
@ -656,13 +658,12 @@ export const canvasSlice = createSlice({
...imageToCommit, ...imageToCommit,
}); });
} }
state.layerState.stagingArea = { state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
...initialLayerState.stagingArea,
};
state.futureLayerStates = []; state.futureLayerStates = [];
state.shouldShowStagingOutline = true; state.shouldShowStagingOutline = true;
state.shouldShowStagingImage = true; state.shouldShowStagingImage = true;
state.batchIds = [];
}, },
fitBoundingBoxToStage: (state) => { fitBoundingBoxToStage: (state) => {
const { const {

View File

@ -98,6 +98,9 @@ export const controlNetSlice = createSlice({
isControlNetEnabledToggled: (state) => { isControlNetEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled; state.isEnabled = !state.isEnabled;
}, },
controlNetEnabled: (state) => {
state.isEnabled = true;
},
controlNetAdded: ( controlNetAdded: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -111,6 +114,12 @@ export const controlNetSlice = createSlice({
controlNetId, controlNetId,
}; };
}, },
controlNetRecalled: (state, action: PayloadAction<ControlNetConfig>) => {
const controlNet = action.payload;
state.controlNets[controlNet.controlNetId] = {
...controlNet,
};
},
controlNetDuplicated: ( controlNetDuplicated: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -439,7 +448,9 @@ export const controlNetSlice = createSlice({
export const { export const {
isControlNetEnabledToggled, isControlNetEnabledToggled,
controlNetEnabled,
controlNetAdded, controlNetAdded,
controlNetRecalled,
controlNetDuplicated, controlNetDuplicated,
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,

View File

@ -93,7 +93,7 @@ const GalleryBoard = ({
const [localBoardName, setLocalBoardName] = useState(board_name); const [localBoardName, setLocalBoardName] = useState(board_name);
const handleSelectBoard = useCallback(() => { const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id)); dispatch(boardIdSelected({ boardId: board_id }));
if (autoAssignBoardOnClick) { if (autoAssignBoardOnClick) {
dispatch(autoAddBoardIdChanged(board_id)); dispatch(autoAddBoardIdChanged(board_id));
} }

View File

@ -34,7 +34,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector); const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector);
const boardName = useBoardName('none'); const boardName = useBoardName('none');
const handleSelectBoard = useCallback(() => { const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected('none')); dispatch(boardIdSelected({ boardId: 'none' }));
if (autoAssignBoardOnClick) { if (autoAssignBoardOnClick) {
dispatch(autoAddBoardIdChanged('none')); dispatch(autoAddBoardIdChanged('none'));
} }

View File

@ -32,7 +32,7 @@ const SystemBoardButton = ({ board_id }: Props) => {
const boardName = useBoardName(board_id); const boardName = useBoardName(board_id);
const handleClick = useCallback(() => { const handleClick = useCallback(() => {
dispatch(boardIdSelected(board_id)); dispatch(boardIdSelected({ boardId: board_id }));
}, [board_id, dispatch]); }, [board_id, dispatch]);
return ( return (

View File

@ -1,8 +1,15 @@
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import {
ControlNetMetadataItem,
CoreMetadata,
LoRAMetadataItem,
} from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useMemo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; import {
isValidControlNetModel,
isValidLoRAModel,
} from '../../../parameters/types/parameterSchemas';
import ImageMetadataItem from './ImageMetadataItem'; import ImageMetadataItem from './ImageMetadataItem';
type Props = { type Props = {
@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA, recallLoRA,
recallControlNet,
} = useRecallParameters(); } = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => { const handleRecallPositivePrompt = useCallback(() => {
@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => {
[recallLoRA] [recallLoRA]
); );
const handleRecallControlNet = useCallback(
(controlnet: ControlNetMetadataItem) => {
recallControlNet(controlnet);
},
[recallControlNet]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
isValidControlNetModel(controlnet.control_model)
)
: [];
}, [metadata?.controlnets]);
if (!metadata || Object.keys(metadata).length === 0) { if (!metadata || Object.keys(metadata).length === 0) {
return null; return null;
} }
@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => {
); );
} }
})} })}
{validControlNets.map((controlnet, index) => (
<ImageMetadataItem
key={index}
label="ControlNet"
value={`${controlnet.control_model?.model_name} - ${controlnet.control_weight}`}
onClick={() => handleRecallControlNet(controlnet)}
/>
))}
</> </>
); );
}; };

View File

@ -35,8 +35,11 @@ export const gallerySlice = createSlice({
autoAssignBoardOnClickChanged: (state, action: PayloadAction<boolean>) => { autoAssignBoardOnClickChanged: (state, action: PayloadAction<boolean>) => {
state.autoAssignBoardOnClick = action.payload; state.autoAssignBoardOnClick = action.payload;
}, },
boardIdSelected: (state, action: PayloadAction<BoardId>) => { boardIdSelected: (
state.selectedBoardId = action.payload; state,
action: PayloadAction<{ boardId: BoardId; selectedImageName?: string }>
) => {
state.selectedBoardId = action.payload.boardId;
state.galleryView = 'images'; state.galleryView = 'images';
}, },
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => { autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {

View File

@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>; export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
const zControlNetMetadataItem = zControlField.deepPartial();
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
export const zCoreMetadata = z export const zCoreMetadata = z
.object({ .object({
app_version: z.string().nullish().catch(null), app_version: z.string().nullish().catch(null),

View File

@ -32,7 +32,8 @@ export const addSDXLRefinerToGraph = (
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string, baseNodeId: string,
modelLoaderNodeId?: string, modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): void => { ): void => {
const { const {
refinerModel, refinerModel,
@ -257,8 +258,30 @@ export const addSDXLRefinerToGraph = (
}; };
} }
graph.edges.push( if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
{ if (isUsingScaledDimensions) {
graph.edges.push({
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[
SDXL_REFINER_INPAINT_CREATE_MASK
] as CreateDenoiseMaskInvocation),
mask: canvasMaskImage,
};
}
}
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.edges.push({
source: { source: {
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE, node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
field: 'image', field: 'image',
@ -267,8 +290,10 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_INPAINT_CREATE_MASK, node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask', field: 'mask',
}, },
}, });
{ }
graph.edges.push({
source: { source: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK, node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'denoise_mask', field: 'denoise_mask',
@ -277,8 +302,7 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_DENOISE_LATENTS, node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'denoise_mask', field: 'denoise_mask',
}, },
} });
);
} }
if ( if (

View File

@ -663,7 +663,8 @@ export const buildCanvasSDXLInpaintGraph = (
graph, graph,
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId, modelLoaderNodeId,
canvasInitImage canvasInitImage,
canvasMaskImage
); );
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;

View File

@ -1,7 +1,7 @@
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { CoreMetadata } from 'features/nodes/types/types'; import { CoreMetadata } from 'features/nodes/types/types';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback } from 'react'; import { useCallback, useEffect } from 'react';
import { useAppToaster } from '../../../app/components/Toaster'; import { useAppToaster } from '../../../app/components/Toaster';
import { useAppDispatch } from '../../../app/store/storeHooks'; import { useAppDispatch } from '../../../app/store/storeHooks';
import { import {
@ -13,18 +13,21 @@ import { setActiveTab } from '../../ui/store/uiSlice';
import { initialImageSelected } from '../store/actions'; import { initialImageSelected } from '../store/actions';
import { useRecallParameters } from './useRecallParameters'; import { useRecallParameters } from './useRecallParameters';
export const usePreselectedImage = (imageName?: string) => { export const usePreselectedImage = (selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
}) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { recallAllParameters } = useRecallParameters(); const { recallAllParameters } = useRecallParameters();
const toaster = useAppToaster(); const toaster = useAppToaster();
const { currentData: selectedImageDto } = useGetImageDTOQuery( const { currentData: selectedImageDto } = useGetImageDTOQuery(
imageName ?? skipToken selectedImage?.imageName ?? skipToken
); );
const { currentData: selectedImageMetadata } = useGetImageMetadataQuery( const { currentData: selectedImageMetadata } = useGetImageMetadataQuery(
imageName ?? skipToken selectedImage?.imageName ?? skipToken
); );
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
@ -54,5 +57,23 @@ export const usePreselectedImage = (imageName?: string) => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [selectedImageMetadata]); }, [selectedImageMetadata]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'sendToCanvas') {
handleSendToCanvas();
}
}, [selectedImage, handleSendToCanvas]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'sendToImg2Img') {
handleSendToImg2Img();
}
}, [selectedImage, handleSendToImg2Img]);
useEffect(() => {
if (selectedImage && selectedImage.action === 'useAllParameters') {
handleUseAllMetadata();
}
}, [selectedImage, handleUseAllMetadata]);
return { handleSendToCanvas, handleSendToImg2Img, handleUseAllMetadata }; return { handleSendToCanvas, handleSendToImg2Img, handleUseAllMetadata };
}; };

View File

@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem,
} from 'features/nodes/types/types';
import { import {
refinerModelChanged, refinerModelChanged,
setNegativeStylePromptSDXL, setNegativeStylePromptSDXL,
@ -18,9 +22,18 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { import {
controlNetModelsAdapter,
loraModelsAdapter, loraModelsAdapter,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models'; } from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions'; import { initialImageSelected, modelSelected } from '../store/actions';
import { import {
@ -38,6 +51,7 @@ import {
isValidCfgScale, isValidCfgScale,
isValidHeight, isValidHeight,
isValidLoRAModel, isValidLoRAModel,
isValidControlNetModel,
isValidMainModel, isValidMainModel,
isValidNegativePrompt, isValidNegativePrompt,
isValidPositivePrompt, isValidPositivePrompt,
@ -53,6 +67,11 @@ import {
isValidStrength, isValidStrength,
isValidWidth, isValidWidth,
} from '../types/parameterSchemas'; } from '../types/parameterSchemas';
import { v4 as uuidv4 } from 'uuid';
import {
CONTROLNET_PROCESSORS,
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
} from 'features/controlNet/store/constants';
const selector = createSelector(stateSelector, ({ generation }) => { const selector = createSelector(stateSelector, ({ generation }) => {
const { model } = generation; const { model } = generation;
@ -390,6 +409,121 @@ export const useRecallParameters = () => {
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall ControlNet with toast
*/
const { controlnets } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => ({
controlnets: result.data
? controlNetModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareControlNetMetadataItem = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
if (!isValidControlNetModel(controlnetMetadataItem.control_model)) {
return { controlnet: null, error: 'Invalid ControlNet model' };
}
const {
image,
control_model,
control_weight,
begin_step_percent,
end_step_percent,
control_mode,
resize_mode,
} = controlnetMetadataItem;
const matchingControlNetModel = controlnets.find(
(c) =>
c.base_model === control_model.base_model &&
c.model_name === control_model.model_name
);
if (!matchingControlNetModel) {
return { controlnet: null, error: 'ControlNet model is not installed' };
}
const isCompatibleBaseModel =
matchingControlNetModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
controlnet: null,
error: 'ControlNet incompatible with currently-selected model',
};
}
const controlNetId = uuidv4();
let processorType = initialControlNet.processorType;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (matchingControlNetModel.model_name.includes(modelSubstring)) {
processorType =
CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] ||
initialControlNet.processorType;
break;
}
}
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
const controlnet: ControlNetConfig = {
isEnabled: true,
model: matchingControlNetModel,
weight:
typeof control_weight === 'number'
? control_weight
: initialControlNet.weight,
beginStepPct: begin_step_percent || initialControlNet.beginStepPct,
endStepPct: end_step_percent || initialControlNet.endStepPct,
controlMode: control_mode || initialControlNet.controlMode,
resizeMode: resize_mode || initialControlNet.resizeMode,
controlImage: image?.image_name || null,
processedControlImage: image?.image_name || null,
processorType,
processorNode:
processorNode.type !== 'none'
? processorNode
: initialControlNet.processorNode,
shouldAutoConfig: true,
controlNetId,
};
return { controlnet, error: null };
},
[controlnets, model?.base_model]
);
const recallControlNet = useCallback(
(controlnetMetadataItem: ControlNetMetadataItem) => {
const result = prepareControlNetMetadataItem(controlnetMetadataItem);
if (!result.controlnet) {
parameterNotSetToast(result.error);
return;
}
dispatch(
controlNetRecalled({
...result.controlnet,
})
);
dispatch(controlNetEnabled());
parameterSetToast();
},
[
prepareControlNetMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/* /*
* Sets image as initial image with toast * Sets image as initial image with toast
*/ */
@ -428,6 +562,7 @@ export const useRecallParameters = () => {
refiner_negative_aesthetic_score, refiner_negative_aesthetic_score,
refiner_start, refiner_start,
loras, loras,
controlnets,
} = metadata; } = metadata;
if (isValidCfgScale(cfg_scale)) { if (isValidCfgScale(cfg_scale)) {
@ -517,6 +652,15 @@ export const useRecallParameters = () => {
} }
}); });
dispatch(controlNetReset());
dispatch(controlNetEnabled());
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
dispatch(controlNetRecalled(result.controlnet));
}
});
allParameterSetToast(); allParameterSetToast();
}, },
[ [
@ -524,6 +668,7 @@ export const useRecallParameters = () => {
allParameterSetToast, allParameterSetToast,
dispatch, dispatch,
prepareLoRAMetadataItem, prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
] ]
); );
@ -542,6 +687,7 @@ export const useRecallParameters = () => {
recallHeight, recallHeight,
recallStrength, recallStrength,
recallLoRA, recallLoRA,
recallControlNet,
recallAllParameters, recallAllParameters,
sendToImageToImage, sendToImageToImage,
}; };

View File

@ -1,5 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { import {
@ -40,7 +41,7 @@ export const useCancelCurrentQueueItem = () => {
}, [currentQueueItemId, dispatch, t, trigger]); }, [currentQueueItemId, dispatch, t, trigger]);
const isDisabled = useMemo( const isDisabled = useMemo(
() => !isConnected || !currentQueueItemId, () => !isConnected || isNil(currentQueueItemId),
[isConnected, currentQueueItemId] [isConnected, currentQueueItemId]
); );