feat(ui): split control layers from raster layers for UI and internal state, same rendering as raster layers

This commit is contained in:
psychedelicious 2024-08-15 10:35:07 +10:00
parent e49b72ee4e
commit d5ca99fc3c
59 changed files with 866 additions and 671 deletions

View File

@ -1679,7 +1679,9 @@
"opacity": "Opacity", "opacity": "Opacity",
"regionalGuidance_withCount": "Regional Guidance ({{count}})", "regionalGuidance_withCount": "Regional Guidance ({{count}})",
"controlAdapters_withCount": "Control Adapters ({{count}})", "controlAdapters_withCount": "Control Adapters ({{count}})",
"layers_withCount": "Raster Layers ({{count}})", "controlLayer": "Control Layer",
"controlLayers_withCount": "Control Layers ({{count}})",
"rasterLayers_withCount": "Raster Layers ({{count}})",
"ipAdapters_withCount": "IP Adapters ({{count}})", "ipAdapters_withCount": "IP Adapters ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)", "globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)", "globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",

View File

@ -2,11 +2,11 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { import {
$lastProgressEvent, $lastProgressEvent,
layerAdded, rasterLayerAdded,
sessionStagingAreaImageAccepted, sessionStagingAreaImageAccepted,
sessionStagingAreaReset, sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
@ -62,12 +62,12 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { imageDTO, offsetX, offsetY } = stagingAreaImage; const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(imageDTO); const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasLayerState> = { const overrides: Partial<CanvasRasterLayerState> = {
position: { x: x + offsetX, y: y + offsetY }, position: { x: x + offsetX, y: y + offsetY },
objects: [imageObject], objects: [imageObject],
}; };
api.dispatch(layerAdded({ overrides })); api.dispatch(rasterLayerAdded({ overrides }));
api.dispatch(sessionStagingAreaReset()); api.dispatch(sessionStagingAreaReset());
}, },
}); });

View File

@ -1,5 +1,5 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { ipaAllDeleted, layerAllDeleted } from 'features/controlLayers/store/canvasV2Slice'; import { ipaAllDeleted, rasterLayerAllDeleted } from 'features/controlLayers/store/canvasV2Slice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors'; import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -22,7 +22,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const imageUsage = getImageUsage(nodes.present, canvasV2, image_name); const imageUsage = getImageUsage(nodes.present, canvasV2, image_name);
if (imageUsage.isLayerImage && !wereLayersReset) { if (imageUsage.isLayerImage && !wereLayersReset) {
dispatch(layerAllDeleted()); dispatch(rasterLayerAllDeleted());
wereLayersReset = true; wereLayersReset = true;
} }

View File

@ -55,7 +55,7 @@ const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO
}; };
const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => { const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.layers.entities.forEach(({ id, objects }) => { state.canvasV2.rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false; let shouldDelete = false;
for (const obj of objects) { for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) { if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
@ -64,7 +64,7 @@ const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
} }
} }
if (shouldDelete) { if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'layer' } })); dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
} }
}); });
}; };

View File

@ -4,10 +4,10 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { import {
ipaImageChanged, ipaImageChanged,
layerAdded, rasterLayerAdded,
rgIPAdapterImageChanged, rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop'; import { isValidDrop } from 'features/dnd/util/isValidDrop';
@ -108,11 +108,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
) { ) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO); const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = getState().canvasV2.bbox.rect; const { x, y } = getState().canvasV2.bbox.rect;
const overrides: Partial<CanvasLayerState> = { const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject], objects: [imageObject],
position: { x, y }, position: { x, y },
}; };
dispatch(layerAdded({ overrides, isSelected: true })); dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return; return;
} }

View File

@ -2,7 +2,6 @@ import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityState } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
@ -18,14 +17,13 @@ import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow'; import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<CanvasEntityState['type'], string> = { const LAYER_TYPE_TO_TKEY = {
control_adapter: 'controlLayers.globalControlAdapter', ip_adapter: 'controlLayers.ipAdapter',
ip_adapter: 'controlLayers.globalIPAdapter',
regional_guidance: 'controlLayers.regionalGuidance',
layer: 'controlLayers.raster',
inpaint_mask: 'controlLayers.inpaintMask', inpaint_mask: 'controlLayers.inpaintMask',
initial_image: 'controlLayers.initialImage', regional_guidance: 'controlLayers.regionalGuidance',
}; raster_layer: 'controlLayers.raster',
control_layer: 'controlLayers.globalControlAdapter',
} as const;
const createSelector = (templates: Templates) => const createSelector = (templates: Templates) =>
createMemoizedSelector( createMemoizedSelector(
@ -125,41 +123,35 @@ const createSelector = (templates: Templates) =>
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') }); reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
} }
// canvasV2.controlAdapters.entities canvasV2.controlLayers.entities
// .filter((ca) => ca.isEnabled) .filter((controlLayer) => controlLayer.isEnabled)
// .forEach((ca, i) => { .forEach((controlLayer, i) => {
// const layerLiteral = i18n.t('controlLayers.layers_one'); const layerLiteral = i18n.t('controlLayers.layers_one');
// const layerNumber = i + 1; const layerNumber = i + 1;
// const layerType = i18n.t(LAYER_TYPE_TO_TKEY[ca.type]); const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
// const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
// const problems: string[] = []; const problems: string[] = [];
// // Must have model // Must have model
// if (!ca.model) { if (!controlLayer.controlAdapter.model) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected')); problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
// } }
// // Model base must match // Model base must match
// if (ca.model?.base !== model?.base) { if (controlLayer.controlAdapter.model?.base !== model?.base) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
// } }
// // Must have a control image OR, if it has a processor, it must have a processed image // T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
// if (!ca.imageObject) { if (controlLayer.controlAdapter.type === 't2i_adapter') {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected')); const multiple = model?.base === 'sdxl' ? 32 : 64;
// } else if (ca.processorConfig && !ca.processedImageObject) { if (bbox.rect.width % multiple !== 0 || bbox.rect.height % multiple !== 0) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed')); problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
// } }
// // T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL) }
// if (ca.adapterType === 't2i_adapter') {
// const multiple = model?.base === 'sdxl' ? 32 : 64;
// if (bbox.rect.width % multiple !== 0 || bbox.rect.height % multiple !== 0) {
// problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
// }
// }
// if (problems.length) { if (problems.length) {
// const content = upperFirst(problems.join(', ')); const content = upperFirst(problems.join(', '));
// reasons.push({ prefix, content }); reasons.push({ prefix, content });
// } }
// }); });
canvasV2.ipAdapters.entities canvasV2.ipAdapters.entities
.filter((ipa) => ipa.isEnabled) .filter((ipa) => ipa.isEnabled)
@ -226,8 +218,9 @@ const createSelector = (templates: Templates) =>
} }
}); });
canvasV2.layers.entities canvasV2.rasterLayers.entities
.filter((l) => l.isEnabled) .filter((l) => l.isEnabled)
.filter((l) => l.type === 'raster_layer')
.forEach((l, i) => { .forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one'); const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1; const layerNumber = i + 1;

View File

@ -1,7 +1,7 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useAddCALayer, useAddIPALayer } from 'features/controlLayers/hooks/addLayerHooks'; import { useAddCALayer, useAddIPALayer } from 'features/controlLayers/hooks/addLayerHooks';
import { layerAdded, rgAdded } from 'features/controlLayers/store/canvasV2Slice'; import { rasterLayerAdded, rgAdded } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi'; import { PiPlusBold } from 'react-icons/pi';
@ -15,7 +15,7 @@ export const AddLayerButton = memo(() => {
dispatch(rgAdded()); dispatch(rgAdded());
}, [dispatch]); }, [dispatch]);
const addRasterLayer = useCallback(() => { const addRasterLayer = useCallback(() => {
dispatch(layerAdded({ isSelected: true })); dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]); }, [dispatch]);
return ( return (

View File

@ -1,8 +1,9 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { ControlLayerEntityList } from 'features/controlLayers/components/ControlLayer/ControlLayerEntityList';
import { InpaintMask } from 'features/controlLayers/components/InpaintMask/InpaintMask'; import { InpaintMask } from 'features/controlLayers/components/InpaintMask/InpaintMask';
import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList'; import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
import { LayerEntityList } from 'features/controlLayers/components/Layer/LayerEntityList'; import { RasterLayerEntityList } from 'features/controlLayers/components/RasterLayer/RasterLayerEntityList';
import { RegionalGuidanceEntityList } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceEntityList'; import { RegionalGuidanceEntityList } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceEntityList';
import { memo } from 'react'; import { memo } from 'react';
@ -13,7 +14,8 @@ export const CanvasEntityList = memo(() => {
<InpaintMask /> <InpaintMask />
<RegionalGuidanceEntityList /> <RegionalGuidanceEntityList />
<IPAdapterList /> <IPAdapterList />
<LayerEntityList /> <ControlLayerEntityList />
<RasterLayerEntityList />
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
); );

View File

@ -0,0 +1,39 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { ControlLayerActionsMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerActionsMenu';
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
type Props = {
id: string;
};
export const ControlLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'control_layer' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<ControlLayerActionsMenu />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
<CanvasEntitySettingsWrapper>
<ControlLayerControlAdapter />
</CanvasEntitySettingsWrapper>
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);
});
ControlLayer.displayName = 'ControlLayer';

View File

@ -0,0 +1,17 @@
import { Menu, MenuList } from '@invoke-ai/ui-library';
import { CanvasEntityActionMenuItems } from 'features/controlLayers/components/common/CanvasEntityActionMenuItems';
import { CanvasEntityMenuButton } from 'features/controlLayers/components/common/CanvasEntityMenuButton';
import { memo } from 'react';
export const ControlLayerActionsMenu = memo(() => {
return (
<Menu>
<CanvasEntityMenuButton />
<MenuList>
<CanvasEntityActionMenuItems />
</MenuList>
</Menu>
);
});
ControlLayerActionsMenu.displayName = 'ControlLayerActionsMenu';

View File

@ -5,50 +5,48 @@ import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlAdapterControlModeSelect } from 'features/controlLayers/components/ControlAdapter/ControlAdapterControlModeSelect'; import { ControlAdapterControlModeSelect } from 'features/controlLayers/components/ControlAdapter/ControlAdapterControlModeSelect';
import { ControlAdapterModel } from 'features/controlLayers/components/ControlAdapter/ControlAdapterModel'; import { ControlAdapterModel } from 'features/controlLayers/components/ControlAdapter/ControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useControlLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { import {
layerControlAdapterBeginEndStepPctChanged, controlLayerBeginEndStepPctChanged,
layerControlAdapterControlModeChanged, controlLayerControlModeChanged,
layerControlAdapterModelChanged, controlLayerModelChanged,
layerControlAdapterWeightChanged, controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { ControlModeV2, ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/store/types'; import type { ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = { export const ControlLayerControlAdapter = memo(() => {
controlAdapter: ControlNetConfig | T2IAdapterConfig;
};
export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { id } = useEntityIdentifierContext(); const entityIdentifier = useEntityIdentifierContext();
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const onChangeBeginEndStepPct = useCallback( const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => { (beginEndStepPct: [number, number]) => {
dispatch(layerControlAdapterBeginEndStepPctChanged({ id, beginEndStepPct })); dispatch(controlLayerBeginEndStepPctChanged({ id: entityIdentifier.id, beginEndStepPct }));
}, },
[dispatch, id] [dispatch, entityIdentifier.id]
); );
const onChangeControlMode = useCallback( const onChangeControlMode = useCallback(
(controlMode: ControlModeV2) => { (controlMode: ControlModeV2) => {
dispatch(layerControlAdapterControlModeChanged({ id, controlMode })); dispatch(controlLayerControlModeChanged({ id: entityIdentifier.id, controlMode }));
}, },
[dispatch, id] [dispatch, entityIdentifier.id]
); );
const onChangeWeight = useCallback( const onChangeWeight = useCallback(
(weight: number) => { (weight: number) => {
dispatch(layerControlAdapterWeightChanged({ id, weight })); dispatch(controlLayerWeightChanged({ id: entityIdentifier.id, weight }));
}, },
[dispatch, id] [dispatch, entityIdentifier.id]
); );
const onChangeModel = useCallback( const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => { (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(layerControlAdapterModelChanged({ id, modelConfig })); dispatch(controlLayerModelChanged({ id: entityIdentifier.id, modelConfig }));
}, },
[dispatch, id] [dispatch, entityIdentifier.id]
); );
return ( return (
@ -63,4 +61,4 @@ export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
); );
}); });
LayerControlAdapter.displayName = 'LayerControlAdapter'; ControlLayerControlAdapter.displayName = 'ControlLayerControlAdapter';

View File

@ -0,0 +1,38 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selectEntityIds = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
return canvasV2.controlLayers.entities.map(mapId).reverse();
});
export const ControlLayerEntityList = memo(() => {
const { t } = useTranslation();
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'control_layer'));
const layerIds = useAppSelector(selectEntityIds);
if (layerIds.length === 0) {
return null;
}
if (layerIds.length > 0) {
return (
<>
<CanvasEntityGroupTitle
title={t('controlLayers.controlLayers_withCount', { count: layerIds.length })}
isSelected={isSelected}
/>
{layerIds.map((id) => (
<ControlLayer key={id} id={id} />
))}
</>
);
}
});
ControlLayerEntityList.displayName = 'ControlLayerEntityList';

View File

@ -10,8 +10,10 @@ import {
PopoverContent, PopoverContent,
PopoverTrigger, PopoverTrigger,
} from '@invoke-ai/ui-library'; } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { MaskOpacity } from 'features/controlLayers/components/MaskOpacity'; import { MaskOpacity } from 'features/controlLayers/components/MaskOpacity';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { import {
clipToBboxChanged, clipToBboxChanged,
invertScrollChanged, invertScrollChanged,
@ -25,6 +27,7 @@ import { RiSettings4Fill } from 'react-icons/ri';
const ControlLayersSettingsPopover = () => { const ControlLayersSettingsPopover = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const canvasManager = useStore($canvasManager);
const clipToBbox = useAppSelector((s) => s.canvasV2.settings.clipToBbox); const clipToBbox = useAppSelector((s) => s.canvasV2.settings.clipToBbox);
const invertScroll = useAppSelector((s) => s.canvasV2.tool.invertScroll); const invertScroll = useAppSelector((s) => s.canvasV2.tool.invertScroll);
const onChangeInvertScroll = useCallback( const onChangeInvertScroll = useCallback(
@ -38,6 +41,21 @@ const ControlLayersSettingsPopover = () => {
const invalidateRasterizationCaches = useCallback(() => { const invalidateRasterizationCaches = useCallback(() => {
dispatch(rasterizationCachesInvalidated()); dispatch(rasterizationCachesInvalidated());
}, [dispatch]); }, [dispatch]);
const calculateBboxes = useCallback(() => {
if (!canvasManager) {
return;
}
for (const adapter of canvasManager.rasterLayerAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
for (const adapter of canvasManager.controlLayerAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
for (const adapter of canvasManager.regionalGuidanceAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
canvasManager.inpaintMaskAdapter.transformer.requestRectCalculation();
}, [canvasManager]);
return ( return (
<Popover isLazy> <Popover isLazy>
<PopoverTrigger> <PopoverTrigger>
@ -58,6 +76,9 @@ const ControlLayersSettingsPopover = () => {
<Button onClick={invalidateRasterizationCaches} size="sm"> <Button onClick={invalidateRasterizationCaches} size="sm">
Invalidate Rasterization Caches Invalidate Rasterization Caches
</Button> </Button>
<Button onClick={calculateBboxes} size="sm">
Calculate Bboxes
</Button>
</Flex> </Flex>
</PopoverBody> </PopoverBody>
</PopoverContent> </PopoverContent>

View File

@ -1,5 +1,5 @@
/* eslint-disable i18next/no-literal-string */ /* eslint-disable i18next/no-literal-string */
import { Button, Flex, Switch } from '@invoke-ai/ui-library'; import { Flex, Switch } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { BrushWidth } from 'features/controlLayers/components/BrushWidth'; import { BrushWidth } from 'features/controlLayers/components/BrushWidth';
@ -12,36 +12,15 @@ import { ResetCanvasButton } from 'features/controlLayers/components/ResetCanvas
import { ToolChooser } from 'features/controlLayers/components/ToolChooser'; import { ToolChooser } from 'features/controlLayers/components/ToolChooser';
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup'; import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { nanoid } from 'features/controlLayers/konva/util';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton'; import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu'; import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
import type { ChangeEvent } from 'react'; import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
const filter = () => {
const entity = $canvasManager.get()?.stateApi.getSelectedEntity();
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter({
type: 'canny_image_processor',
id: nanoid(),
low_threshold: 50,
high_threshold: 50,
});
};
export const ControlLayersToolbar = memo(() => { export const ControlLayersToolbar = memo(() => {
const tool = useAppSelector((s) => s.canvasV2.tool.selected); const tool = useAppSelector((s) => s.canvasV2.tool.selected);
const canvasManager = useStore($canvasManager); const canvasManager = useStore($canvasManager);
const bbox = useCallback(() => {
if (!canvasManager) {
return;
}
for (const l of canvasManager.layers.values()) {
l.transformer.requestRectCalculation();
}
}, [canvasManager]);
const onChangeDebugging = useCallback( const onChangeDebugging = useCallback(
(e: ChangeEvent<HTMLInputElement>) => { (e: ChangeEvent<HTMLInputElement>) => {
if (!canvasManager) { if (!canvasManager) {
@ -61,7 +40,6 @@ export const ControlLayersToolbar = memo(() => {
<Flex gap={2} marginInlineEnd="auto" alignItems="center"> <Flex gap={2} marginInlineEnd="auto" alignItems="center">
<ToggleProgressButton /> <ToggleProgressButton />
<ToolChooser /> <ToolChooser />
<Button onClick={filter}>Filter</Button>
</Flex> </Flex>
</Flex> </Flex>
<Flex flex={1} gap={2} justifyContent="center" alignItems="center"> <Flex flex={1} gap={2} justifyContent="center" alignItems="center">
@ -70,7 +48,6 @@ export const ControlLayersToolbar = memo(() => {
</Flex> </Flex>
<CanvasScale /> <CanvasScale />
<CanvasResetViewButton /> <CanvasResetViewButton />
<Button onClick={bbox}>bbox</Button>
<Switch onChange={onChangeDebugging}>debug</Switch> <Switch onChange={onChangeDebugging}>debug</Switch>
<Flex flex={1} justifyContent="center"> <Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center"> <Flex gap={2} marginInlineStart="auto" alignItems="center">

View File

@ -13,7 +13,7 @@ export const DeleteAllLayersButton = memo(() => {
s.canvasV2.regions.entities.length + s.canvasV2.regions.entities.length +
// s.canvasV2.controlAdapters.entities.length + // s.canvasV2.controlAdapters.entities.length +
s.canvasV2.ipAdapters.entities.length + s.canvasV2.ipAdapters.entities.length +
s.canvasV2.layers.entities.length s.canvasV2.rasterLayers.entities.length
); );
}); });
const onClick = useCallback(() => { const onClick = useCallback(() => {

View File

@ -18,7 +18,7 @@ export const Filter = memo(() => {
return; return;
} }
const entity = canvasManager.stateApi.getEntity(filteringEntity); const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') { if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return; return;
} }
entity.adapter.filter.previewFilter(); entity.adapter.filter.previewFilter();
@ -33,7 +33,7 @@ export const Filter = memo(() => {
return; return;
} }
const entity = canvasManager.stateApi.getEntity(filteringEntity); const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') { if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return; return;
} }
entity.adapter.filter.applyFilter(); entity.adapter.filter.applyFilter();
@ -48,7 +48,7 @@ export const Filter = memo(() => {
return; return;
} }
const entity = canvasManager.stateApi.getEntity(filteringEntity); const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') { if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return; return;
} }
entity.adapter.filter.cancelFilter(); entity.adapter.filter.cancelFilter();

View File

@ -1,17 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useFilter } from 'features/controlLayers/components/Filters/Filter';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
const FilterWrapper = (props: PropsWithChildren) => {
const isPreviewDisabled = useAppSelector((s) => s.canvasV2.selectedEntityIdentifier?.type !== 'layer');
const filter = useFilter();
return (
<Flex flexDir="column" gap={3} w="full" h="full">
{props.children}
</Flex>
);
};
export default memo(FilterWrapper);

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library'; import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer'; import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton'; import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -15,18 +15,17 @@ type Props = {
export const IPAdapter = memo(({ id }: Props) => { export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]); const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return ( return (
<EntityIdentifierContext.Provider value={entityIdentifier}> <EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer> <CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}> <CanvasEntityHeader>
<CanvasEntityEnabledToggle /> <CanvasEntityEnabledToggle />
<CanvasEntityTitle /> <CanvasEntityTitle />
<Spacer /> <Spacer />
<CanvasEntityDeleteButton /> <CanvasEntityDeleteButton />
</CanvasEntityHeader> </CanvasEntityHeader>
{isOpen && <IPAdapterSettings />} <IPAdapterSettings />
</CanvasEntityContainer> </CanvasEntityContainer>
</EntityIdentifierContext.Provider> </EntityIdentifierContext.Provider>
); );

View File

@ -1,7 +1,7 @@
import { Box, Flex } from '@invoke-ai/ui-library'; import { Box, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct'; import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings'; import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight'; import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod'; import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@ -73,7 +73,7 @@ export const IPAdapterSettings = memo(() => {
const postUploadAction = useMemo<IPALayerImagePostUploadAction>(() => ({ type: 'SET_IPA_IMAGE', id }), [id]); const postUploadAction = useMemo<IPALayerImagePostUploadAction>(() => ({ type: 'SET_IPA_IMAGE', id }), [id]);
return ( return (
<CanvasEntitySettings> <CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={4} position="relative" w="full"> <Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full"> <Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s"> <Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
@ -102,7 +102,7 @@ export const IPAdapterSettings = memo(() => {
</Flex> </Flex>
</Flex> </Flex>
</Flex> </Flex>
</CanvasEntitySettings> </CanvasEntitySettingsWrapper>
); );
}); });

View File

@ -1,32 +1,38 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library'; import { Spacer } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer'; import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader'; import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle'; import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { InpaintMaskActionsMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskActionsMenu'; import { InpaintMaskActionsMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskActionsMenu';
import { InpaintMaskSettings } from 'features/controlLayers/components/InpaintMask/InpaintMaskSettings';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { InpaintMaskMaskFillColorPicker } from './InpaintMaskMaskFillColorPicker'; import { InpaintMaskMaskFillColorPicker } from './InpaintMaskMaskFillColorPicker';
export const InpaintMask = memo(() => { export const InpaintMask = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id: 'inpaint_mask', type: 'inpaint_mask' }), []); const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id: 'inpaint_mask', type: 'inpaint_mask' }), []);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: false }); const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'inpaint_mask'));
return ( return (
<>
<CanvasEntityGroupTitle title={t('controlLayers.inpaintMask')} isSelected={isSelected} />
<EntityIdentifierContext.Provider value={entityIdentifier}> <EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer> <CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}> <CanvasEntityHeader>
<CanvasEntityEnabledToggle /> <CanvasEntityEnabledToggle />
<CanvasEntityTitle /> <CanvasEntityTitle />
<Spacer /> <Spacer />
<InpaintMaskMaskFillColorPicker /> <InpaintMaskMaskFillColorPicker />
<InpaintMaskActionsMenu /> <InpaintMaskActionsMenu />
</CanvasEntityHeader> </CanvasEntityHeader>
{isOpen && <InpaintMaskSettings />}
</CanvasEntityContainer> </CanvasEntityContainer>
</EntityIdentifierContext.Provider> </EntityIdentifierContext.Provider>
</>
); );
}); });

View File

@ -1,8 +0,0 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { memo } from 'react';
export const InpaintMaskSettings = memo(() => {
return <CanvasEntitySettings>PLACEHOLDER</CanvasEntitySettings>;
});
InpaintMaskSettings.displayName = 'InpaintMaskSettings';

View File

@ -1,22 +0,0 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { LayerControlAdapter } from 'features/controlLayers/components/Layer/LayerControlAdapter';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { memo } from 'react';
export const LayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const controlAdapter = useLayerControlAdapter(entityIdentifier);
if (!controlAdapter) {
return null;
}
return (
<CanvasEntitySettings>
<LayerControlAdapter controlAdapter={controlAdapter} />
</CanvasEntitySettings>
);
});
LayerSettings.displayName = 'LayerSettings';

View File

@ -4,8 +4,7 @@ import { CanvasEntityDeleteButton } from 'features/controlLayers/components/comm
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader'; import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle'; import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { LayerActionsMenu } from 'features/controlLayers/components/Layer/LayerActionsMenu'; import { RasterLayerActionsMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerActionsMenu';
import { LayerSettings } from 'features/controlLayers/components/Layer/LayerSettings';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
@ -14,8 +13,8 @@ type Props = {
id: string; id: string;
}; };
export const Layer = memo(({ id }: Props) => { export const RasterLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'layer' }), [id]); const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'raster_layer' }), [id]);
return ( return (
<EntityIdentifierContext.Provider value={entityIdentifier}> <EntityIdentifierContext.Provider value={entityIdentifier}>
@ -24,13 +23,12 @@ export const Layer = memo(({ id }: Props) => {
<CanvasEntityEnabledToggle /> <CanvasEntityEnabledToggle />
<CanvasEntityTitle /> <CanvasEntityTitle />
<Spacer /> <Spacer />
<LayerActionsMenu /> <RasterLayerActionsMenu />
<CanvasEntityDeleteButton /> <CanvasEntityDeleteButton />
</CanvasEntityHeader> </CanvasEntityHeader>
<LayerSettings />
</CanvasEntityContainer> </CanvasEntityContainer>
</EntityIdentifierContext.Provider> </EntityIdentifierContext.Provider>
); );
}); });
Layer.displayName = 'Layer'; RasterLayer.displayName = 'RasterLayer';

View File

@ -3,7 +3,7 @@ import { CanvasEntityActionMenuItems } from 'features/controlLayers/components/c
import { CanvasEntityMenuButton } from 'features/controlLayers/components/common/CanvasEntityMenuButton'; import { CanvasEntityMenuButton } from 'features/controlLayers/components/common/CanvasEntityMenuButton';
import { memo } from 'react'; import { memo } from 'react';
export const LayerActionsMenu = memo(() => { export const RasterLayerActionsMenu = memo(() => {
return ( return (
<Menu> <Menu>
<CanvasEntityMenuButton /> <CanvasEntityMenuButton />
@ -14,4 +14,4 @@ export const LayerActionsMenu = memo(() => {
); );
}); });
LayerActionsMenu.displayName = 'LayerActionsMenu'; RasterLayerActionsMenu.displayName = 'RasterLayerActionsMenu';

View File

@ -1,19 +1,19 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle'; import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { Layer } from 'features/controlLayers/components/Layer/Layer'; import { RasterLayer } from 'features/controlLayers/components/RasterLayer/RasterLayer';
import { mapId } from 'features/controlLayers/konva/util'; import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selectEntityIds = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { const selectEntityIds = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
return canvasV2.layers.entities.map(mapId).reverse(); return canvasV2.rasterLayers.entities.map(mapId).reverse();
}); });
export const LayerEntityList = memo(() => { export const RasterLayerEntityList = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'layer')); const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'raster_layer'));
const layerIds = useAppSelector(selectEntityIds); const layerIds = useAppSelector(selectEntityIds);
if (layerIds.length === 0) { if (layerIds.length === 0) {
@ -24,15 +24,15 @@ export const LayerEntityList = memo(() => {
return ( return (
<> <>
<CanvasEntityGroupTitle <CanvasEntityGroupTitle
title={t('controlLayers.layers_withCount', { count: layerIds.length })} title={t('controlLayers.rasterLayers_withCount', { count: layerIds.length })}
isSelected={isSelected} isSelected={isSelected}
/> />
{layerIds.map((id) => ( {layerIds.map((id) => (
<Layer key={id} id={id} /> <RasterLayer key={id} id={id} />
))} ))}
</> </>
); );
} }
}); });
LayerEntityList.displayName = 'LayerEntityList'; RasterLayerEntityList.displayName = 'RasterLayerEntityList';

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library'; import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer'; import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton'; import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -20,11 +20,10 @@ type Props = {
export const RegionalGuidance = memo(({ id }: Props) => { export const RegionalGuidance = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'regional_guidance' }), [id]); const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'regional_guidance' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return ( return (
<EntityIdentifierContext.Provider value={entityIdentifier}> <EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer> <CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}> <CanvasEntityHeader>
<CanvasEntityEnabledToggle /> <CanvasEntityEnabledToggle />
<CanvasEntityTitle /> <CanvasEntityTitle />
<Spacer /> <Spacer />
@ -34,7 +33,7 @@ export const RegionalGuidance = memo(({ id }: Props) => {
<RegionalGuidanceActionsMenu /> <RegionalGuidanceActionsMenu />
<CanvasEntityDeleteButton /> <CanvasEntityDeleteButton />
</CanvasEntityHeader> </CanvasEntityHeader>
{isOpen && <RegionalGuidanceSettings />} <RegionalGuidanceSettings />
</CanvasEntityContainer> </CanvasEntityContainer>
</EntityIdentifierContext.Provider> </EntityIdentifierContext.Provider>
); );

View File

@ -1,6 +1,6 @@
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { AddPromptButtons } from 'features/controlLayers/components/AddPromptButtons'; import { AddPromptButtons } from 'features/controlLayers/components/AddPromptButtons';
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings'; import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers'; import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo } from 'react'; import { memo } from 'react';
@ -16,12 +16,12 @@ export const RegionalGuidanceSettings = memo(() => {
const hasIPAdapters = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).ipAdapters.length > 0); const hasIPAdapters = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).ipAdapters.length > 0);
return ( return (
<CanvasEntitySettings> <CanvasEntitySettingsWrapper>
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons id={id} />} {!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons id={id} />}
{hasPositivePrompt && <RegionalGuidancePositivePrompt id={id} />} {hasPositivePrompt && <RegionalGuidancePositivePrompt id={id} />}
{hasNegativePrompt && <RegionalGuidanceNegativePrompt id={id} />} {hasNegativePrompt && <RegionalGuidanceNegativePrompt id={id} />}
{hasIPAdapters && <RegionalGuidanceIPAdapters id={id} />} {hasIPAdapters && <RegionalGuidanceIPAdapters id={id} />}
</CanvasEntitySettings> </CanvasEntitySettingsWrapper>
); );
}); });

View File

@ -3,16 +3,18 @@ import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerUseAsControl } from 'features/controlLayers/hooks/useLayerControlAdapter'; import { useDefaultControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { import {
$filteringEntity, $filteringEntity,
controlLayerConvertedToRasterLayer,
entityArrangedBackwardOne, entityArrangedBackwardOne,
entityArrangedForwardOne, entityArrangedForwardOne,
entityArrangedToBack, entityArrangedToBack,
entityArrangedToFront, entityArrangedToFront,
entityDeleted, entityDeleted,
entityReset, entityReset,
rasterLayerConvertedToControlLayer,
selectCanvasV2Slice, selectCanvasV2Slice,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityIdentifier, CanvasV2State } from 'features/controlLayers/store/types'; import type { CanvasEntityIdentifier, CanvasV2State } from 'features/controlLayers/store/types';
@ -28,17 +30,21 @@ import {
PiQuestionMarkBold, PiQuestionMarkBold,
PiStarHalfBold, PiStarHalfBold,
PiTrashSimpleBold, PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi'; } from 'react-icons/pi';
const getIndexAndCount = ( const getIndexAndCount = (
canvasV2: CanvasV2State, canvasV2: CanvasV2State,
{ id, type }: CanvasEntityIdentifier { id, type }: CanvasEntityIdentifier
): { index: number; count: number } => { ): { index: number; count: number } => {
if (type === 'layer') { if (type === 'raster_layer') {
return { return {
index: canvasV2.layers.entities.findIndex((entity) => entity.id === id), index: canvasV2.rasterLayers.entities.findIndex((entity) => entity.id === id),
count: canvasV2.layers.entities.length, count: canvasV2.rasterLayers.entities.length,
};
} else if (type === 'control_layer') {
return {
index: canvasV2.controlLayers.entities.findIndex((entity) => entity.id === id),
count: canvasV2.controlLayers.entities.length,
}; };
} else if (type === 'regional_guidance') { } else if (type === 'regional_guidance') {
return { return {
@ -58,7 +64,6 @@ export const CanvasEntityActionMenuItems = memo(() => {
const canvasManager = useStore($canvasManager); const canvasManager = useStore($canvasManager);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext(); const entityIdentifier = useEntityIdentifierContext();
const useAsControl = useLayerUseAsControl(entityIdentifier);
const selectValidActions = useMemo( const selectValidActions = useMemo(
() => () =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
@ -76,16 +81,39 @@ export const CanvasEntityActionMenuItems = memo(() => {
const validActions = useAppSelector(selectValidActions); const validActions = useAppSelector(selectValidActions);
const isArrangeable = useMemo( const isArrangeable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance', () =>
entityIdentifier.type === 'raster_layer' ||
entityIdentifier.type === 'control_layer' ||
entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type] [entityIdentifier.type]
); );
const isDeleteable = useMemo( const isDeleteable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance', () =>
entityIdentifier.type === 'raster_layer' ||
entityIdentifier.type === 'control_layer' ||
entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type] [entityIdentifier.type]
); );
const isFilterable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const isUseAsControlable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]); const isFilterable = useMemo(
() => entityIdentifier.type === 'raster_layer' || entityIdentifier.type === 'control_layer',
[entityIdentifier.type]
);
const isRasterLayer = useMemo(() => entityIdentifier.type === 'raster_layer', [entityIdentifier.type]);
const isControlLayer = useMemo(() => entityIdentifier.type === 'control_layer', [entityIdentifier.type]);
const defaultControlAdapter = useDefaultControlAdapter();
const convertRasterLayerToControlLayer = useCallback(() => {
dispatch(rasterLayerConvertedToControlLayer({ id: entityIdentifier.id, controlAdapter: defaultControlAdapter }));
}, [dispatch, defaultControlAdapter, entityIdentifier.id]);
const convertControlLayerToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ id: entityIdentifier.id }));
}, [dispatch, entityIdentifier.id]);
const deleteEntity = useCallback(() => { const deleteEntity = useCallback(() => {
dispatch(entityDeleted({ entityIdentifier })); dispatch(entityDeleted({ entityIdentifier }));
@ -142,9 +170,14 @@ export const CanvasEntityActionMenuItems = memo(() => {
{t('common.filter')} {t('common.filter')}
</MenuItem> </MenuItem>
)} )}
{isUseAsControlable && ( {isRasterLayer && (
<MenuItem onClick={useAsControl.toggle} icon={useAsControl.hasControlAdapter ? <PiXBold /> : <PiCheckBold />}> <MenuItem onClick={convertRasterLayerToControlLayer} icon={<PiCheckBold />}>
{useAsControl.hasControlAdapter ? t('common.removeControl') : t('common.useAsControl')} {t('common.convertToControlLayer')}
</MenuItem>
)}
{isControlLayer && (
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiCheckBold />}>
{t('common.convertToRasterLayer')}
</MenuItem> </MenuItem>
)} )}
<MenuDivider /> <MenuDivider />

View File

@ -8,7 +8,7 @@ type Props = {
export const CanvasEntityGroupTitle = memo(({ title, isSelected }: Props) => { export const CanvasEntityGroupTitle = memo(({ title, isSelected }: Props) => {
return ( return (
<Text color={isSelected ? 'base.100' : 'base.300'} userSelect="none"> <Text color={isSelected ? 'base.200' : 'base.500'} fontWeight="semibold" userSelect="none">
{title} {title}
</Text> </Text>
); );

View File

@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react'; import type { PropsWithChildren } from 'react';
import { memo } from 'react'; import { memo } from 'react';
export const CanvasEntitySettings = memo(({ children }: PropsWithChildren) => { export const CanvasEntitySettingsWrapper = memo(({ children }: PropsWithChildren) => {
return ( return (
<Flex flexDir="column" gap={3} px={3} pb={3}> <Flex flexDir="column" gap={3} px={3} pb={3}>
{children} {children}
@ -10,4 +10,4 @@ export const CanvasEntitySettings = memo(({ children }: PropsWithChildren) => {
); );
}); });
CanvasEntitySettings.displayName = 'CanvasEntitySettings'; CanvasEntitySettingsWrapper.displayName = 'CanvasEntitySettingsWrapper';

View File

@ -1,10 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { import { entityReset, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
entityReset,
selectCanvasV2Slice,
} from 'features/controlLayers/store/canvasV2Slice';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
@ -17,7 +14,6 @@ export function useCanvasResetLayerHotkey() {
useAssertSingleton(useCanvasResetLayerHotkey.name); useAssertSingleton(useCanvasResetLayerHotkey.name);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier); const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isStaging = useAppSelector((s) => s.canvasV2.session.isStaging);
const resetSelectedLayer = useCallback(() => { const resetSelectedLayer = useCallback(() => {
if (selectedEntityIdentifier === null) { if (selectedEntityIdentifier === null) {
@ -27,16 +23,9 @@ export function useCanvasResetLayerHotkey() {
}, [dispatch, selectedEntityIdentifier]); }, [dispatch, selectedEntityIdentifier]);
const isResetEnabled = useMemo( const isResetEnabled = useMemo(
() => () => selectedEntityIdentifier?.type === 'inpaint_mask',
(!isStaging && selectedEntityIdentifier?.type === 'layer') || [selectedEntityIdentifier?.type]
selectedEntityIdentifier?.type === 'regional_guidance' ||
selectedEntityIdentifier?.type === 'inpaint_mask',
[isStaging, selectedEntityIdentifier?.type]
); );
useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [ useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [isResetEnabled, resetSelectedLayer]);
isResetEnabled,
isStaging,
resetSelectedLayer,
]);
} }

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice, selectEntity } from 'features/controlLayers/store/canvasV2Slice'; import { selectCanvasV2Slice, selectEntity } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; import { type CanvasEntityIdentifier,isDrawableEntity } from 'features/controlLayers/store/types';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useEntityObjectCount = (entityIdentifier: CanvasEntityIdentifier) => { export const useEntityObjectCount = (entityIdentifier: CanvasEntityIdentifier) => {
@ -11,11 +11,7 @@ export const useEntityObjectCount = (entityIdentifier: CanvasEntityIdentifier) =
const entity = selectEntity(canvasV2, entityIdentifier); const entity = selectEntity(canvasV2, entityIdentifier);
if (!entity) { if (!entity) {
return 0; return 0;
} else if (entity.type === 'layer') { } else if (isDrawableEntity(entity)) {
return entity.objects.length;
} else if (entity.type === 'inpaint_mask') {
return entity.objects.length;
} else if (entity.type === 'regional_guidance') {
return entity.objects.length; return entity.objects.length;
} else { } else {
return 0; return 0;

View File

@ -13,10 +13,10 @@ export const useEntityTitle = (entityIdentifier: CanvasEntityIdentifier) => {
const parts: string[] = []; const parts: string[] = [];
if (entityIdentifier.type === 'inpaint_mask') { if (entityIdentifier.type === 'inpaint_mask') {
parts.push(t('controlLayers.inpaintMask')); parts.push(t('controlLayers.inpaintMask'));
} else if (entityIdentifier.type === 'control_adapter') { } else if (entityIdentifier.type === 'control_layer') {
parts.push(t('controlLayers.globalControlAdapter')); parts.push(t('controlLayers.controlLayer'));
} else if (entityIdentifier.type === 'layer') { } else if (entityIdentifier.type === 'raster_layer') {
parts.push(t('controlLayers.layer')); parts.push(t('controlLayers.rasterLayer'));
} else if (entityIdentifier.type === 'ip_adapter') { } else if (entityIdentifier.type === 'ip_adapter') {
parts.push(t('controlLayers.ipAdapter')); parts.push(t('controlLayers.ipAdapter'));
} else if (entityIdentifier.type === 'regional_guidance') { } else if (entityIdentifier.type === 'regional_guidance') {

View File

@ -1,23 +1,19 @@
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { layerUsedAsControlChanged, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice'; import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectLayer } from 'features/controlLayers/store/layersReducers'; import { selectControlLayerOrThrow } from 'features/controlLayers/store/controlLayersReducers';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { initialControlNetV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types'; import { initialControlNetV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react'; import { useMemo } from 'react';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType'; import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => { export const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const selectControlAdapter = useMemo( const selectControlAdapter = useMemo(
() => () =>
createMemoizedAppSelector(selectCanvasV2Slice, (canvasV2) => { createMemoizedAppSelector(selectCanvasV2Slice, (canvasV2) => {
const layer = selectLayer(canvasV2, entityIdentifier.id); const layer = selectControlLayerOrThrow(canvasV2, entityIdentifier.id);
if (!layer) {
return null;
}
return layer.controlAdapter; return layer.controlAdapter;
}), }),
[entityIdentifier] [entityIdentifier]
@ -26,32 +22,23 @@ export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier)
return controlAdapter; return controlAdapter;
}; };
export const useLayerUseAsControl = (entityIdentifier: CanvasEntityIdentifier) => { export const useDefaultControlAdapter = () => {
const dispatch = useAppDispatch();
const [modelConfigs] = useControlNetAndT2IAdapterModels(); const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base); const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
const controlAdapter = useLayerControlAdapter(entityIdentifier);
const model: ControlNetModelConfig | T2IAdapterModelConfig | null = useMemo(() => { const defaultControlAdapter = useMemo(() => {
// prefer to use a model that matches the base model
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true)); const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
return compatibleModels[0] ?? modelConfigs[0] ?? null; const model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}, [baseModel, modelConfigs]); const controlAdapter =
model?.type === 't2i_adapter' ? deepClone(initialT2IAdapterV2) : deepClone(initialControlNetV2);
const toggle = useCallback(() => {
if (controlAdapter) {
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: null }));
return;
}
const newControlAdapter = deepClone(model?.type === 't2i_adapter' ? initialT2IAdapterV2 : initialControlNetV2);
if (model) { if (model) {
newControlAdapter.model = zModelIdentifierField.parse(model); controlAdapter.model = zModelIdentifierField.parse(model);
} }
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: newControlAdapter })); return controlAdapter;
}, [controlAdapter, dispatch, entityIdentifier.id, model]); }, [baseModel, modelConfigs]);
return { hasControlAdapter: Boolean(controlAdapter), toggle }; return defaultControlAdapter;
}; };

View File

@ -4,7 +4,12 @@ import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer'; import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer'; import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer';
import type { CanvasEntityIdentifier, CanvasLayerState, CanvasV2State } from 'features/controlLayers/store/types'; import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasV2State,
} from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import { get } from 'lodash-es'; import { get } from 'lodash-es';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
@ -17,7 +22,7 @@ export class CanvasLayerAdapter {
manager: CanvasManager; manager: CanvasManager;
log: Logger; log: Logger;
state: CanvasLayerState; state: CanvasRasterLayerState | CanvasControlLayerState;
konva: { konva: {
layer: Konva.Layer; layer: Konva.Layer;
@ -110,7 +115,7 @@ export class CanvasLayerAdapter {
this.konva.layer.visible(isEnabled); this.konva.layer.visible(isEnabled);
}; };
updateObjects = async (arg?: { objects: CanvasLayerState['objects'] }) => { updateObjects = async (arg?: { objects: CanvasRasterLayerState['objects'] }) => {
this.log.trace('Updating objects'); this.log.trace('Updating objects');
const objects = get(arg, 'objects', this.state.objects); const objects = get(arg, 'objects', this.state.objects);

View File

@ -29,7 +29,6 @@ import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types'; import type { ImageDTO } from 'services/api/types';
import { CanvasBackground } from './CanvasBackground'; import { CanvasBackground } from './CanvasBackground';
import type { CanvasControlAdapter } from './CanvasControlAdapter';
import { CanvasLayerAdapter } from './CanvasLayerAdapter'; import { CanvasLayerAdapter } from './CanvasLayerAdapter';
import { CanvasMaskAdapter } from './CanvasMaskAdapter'; import { CanvasMaskAdapter } from './CanvasMaskAdapter';
import { CanvasPreview } from './CanvasPreview'; import { CanvasPreview } from './CanvasPreview';
@ -46,10 +45,10 @@ export class CanvasManager {
path: string[]; path: string[];
stage: Konva.Stage; stage: Konva.Stage;
container: HTMLDivElement; container: HTMLDivElement;
controlAdapters: Map<string, CanvasControlAdapter>; rasterLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
layers: Map<string, CanvasLayerAdapter>; controlLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
regions: Map<string, CanvasMaskAdapter>; regionalGuidanceAdapters: Map<string, CanvasMaskAdapter> = new Map();
inpaintMask: CanvasMaskAdapter; inpaintMaskAdapter: CanvasMaskAdapter;
stateApi: CanvasStateApi; stateApi: CanvasStateApi;
preview: CanvasPreview; preview: CanvasPreview;
background: CanvasBackground; background: CanvasBackground;
@ -94,10 +93,6 @@ export class CanvasManager {
this.background = new CanvasBackground(this); this.background = new CanvasBackground(this);
this.stage.add(this.background.konva.layer); this.stage.add(this.background.konva.layer);
this.layers = new Map();
this.regions = new Map();
this.controlAdapters = new Map();
this._worker.onmessage = (event: MessageEvent<ExtentsResult | WorkerLogMessage>) => { this._worker.onmessage = (event: MessageEvent<ExtentsResult | WorkerLogMessage>) => {
const { type, data } = event.data; const { type, data } = event.data;
if (type === 'log') { if (type === 'log') {
@ -128,8 +123,8 @@ export class CanvasManager {
this.stateApi.$currentFill.set(this.stateApi.getCurrentFill()); this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity()); this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.inpaintMask = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this); this.inpaintMaskAdapter = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this);
this.stage.add(this.inpaintMask.konva.layer); this.stage.add(this.inpaintMaskAdapter.konva.layer);
} }
enableDebugging() { enableDebugging() {
@ -152,18 +147,24 @@ export class CanvasManager {
} }
arrangeEntities() { arrangeEntities() {
const { getLayersState, getRegionsState } = this.stateApi;
const layers = getLayersState().entities;
const regions = getRegionsState().entities;
let zIndex = 0; let zIndex = 0;
this.background.konva.layer.zIndex(++zIndex); this.background.konva.layer.zIndex(++zIndex);
for (const layer of layers) {
this.layers.get(layer.id)?.konva.layer.zIndex(++zIndex); for (const layer of this.stateApi.getRasterLayersState().entities) {
this.rasterLayerAdapters.get(layer.id)?.konva.layer.zIndex(++zIndex);
} }
for (const rg of regions) {
this.regions.get(rg.id)?.konva.layer.zIndex(++zIndex); for (const layer of this.stateApi.getControlLayersState().entities) {
this.controlLayerAdapters.get(layer.id)?.konva.layer.zIndex(++zIndex);
} }
this.inpaintMask.konva.layer.zIndex(++zIndex);
for (const rg of this.stateApi.getRegionsState().entities) {
this.regionalGuidanceAdapters.get(rg.id)?.konva.layer.zIndex(++zIndex);
}
this.inpaintMaskAdapter.konva.layer.zIndex(++zIndex);
this.preview.getLayer().zIndex(++zIndex); this.preview.getLayer().zIndex(++zIndex);
} }
@ -215,12 +216,14 @@ export class CanvasManager {
const { id, type } = transformingEntity; const { id, type } = transformingEntity;
if (type === 'layer') { if (type === 'raster_layer') {
return this.layers.get(id) ?? null; return this.rasterLayerAdapters.get(id) ?? null;
} else if (type === 'control_layer') {
return this.controlLayerAdapters.get(id) ?? null;
} else if (type === 'inpaint_mask') { } else if (type === 'inpaint_mask') {
return this.inpaintMask; return this.inpaintMaskAdapter;
} else if (type === 'regional_guidance') { } else if (type === 'regional_guidance') {
return this.regions.get(id) ?? null; return this.regionalGuidanceAdapters.get(id) ?? null;
} }
return null; return null;
@ -268,21 +271,46 @@ export class CanvasManager {
return; return;
} }
if (this._isFirstRender || state.layers.entities !== this._prevState.layers.entities) { if (this._isFirstRender || state.rasterLayers.entities !== this._prevState.rasterLayers.entities) {
this.log.debug('Rendering layers'); this.log.debug('Rendering raster layers');
for (const canvasLayer of this.layers.values()) { for (const canvasLayer of this.rasterLayerAdapters.values()) {
if (!state.layers.entities.find((l) => l.id === canvasLayer.id)) { if (!state.rasterLayers.entities.find((l) => l.id === canvasLayer.id)) {
await canvasLayer.destroy(); await canvasLayer.destroy();
this.layers.delete(canvasLayer.id); this.rasterLayerAdapters.delete(canvasLayer.id);
} }
} }
for (const entityState of state.layers.entities) { for (const entityState of state.rasterLayers.entities) {
let adapter = this.layers.get(entityState.id); let adapter = this.rasterLayerAdapters.get(entityState.id);
if (!adapter) { if (!adapter) {
adapter = new CanvasLayerAdapter(entityState, this); adapter = new CanvasLayerAdapter(entityState, this);
this.layers.set(adapter.id, adapter); this.rasterLayerAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer);
}
await adapter.update({
state: entityState,
toolState: state.tool,
isSelected: state.selectedEntityIdentifier?.id === entityState.id,
});
}
}
if (this._isFirstRender || state.controlLayers.entities !== this._prevState.controlLayers.entities) {
this.log.debug('Rendering control layers');
for (const canvasLayer of this.controlLayerAdapters.values()) {
if (!state.controlLayers.entities.find((l) => l.id === canvasLayer.id)) {
await canvasLayer.destroy();
this.controlLayerAdapters.delete(canvasLayer.id);
}
}
for (const entityState of state.controlLayers.entities) {
let adapter = this.controlLayerAdapters.get(entityState.id);
if (!adapter) {
adapter = new CanvasLayerAdapter(entityState, this);
this.controlLayerAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer); this.stage.add(adapter.konva.layer);
} }
await adapter.update({ await adapter.update({
@ -303,18 +331,18 @@ export class CanvasManager {
this.log.debug('Rendering regions'); this.log.debug('Rendering regions');
// Destroy the konva nodes for nonexistent entities // Destroy the konva nodes for nonexistent entities
for (const canvasRegion of this.regions.values()) { for (const canvasRegion of this.regionalGuidanceAdapters.values()) {
if (!state.regions.entities.find((rg) => rg.id === canvasRegion.id)) { if (!state.regions.entities.find((rg) => rg.id === canvasRegion.id)) {
canvasRegion.destroy(); canvasRegion.destroy();
this.regions.delete(canvasRegion.id); this.regionalGuidanceAdapters.delete(canvasRegion.id);
} }
} }
for (const entityState of state.regions.entities) { for (const entityState of state.regions.entities) {
let adapter = this.regions.get(entityState.id); let adapter = this.regionalGuidanceAdapters.get(entityState.id);
if (!adapter) { if (!adapter) {
adapter = new CanvasMaskAdapter(entityState, this); adapter = new CanvasMaskAdapter(entityState, this);
this.regions.set(adapter.id, adapter); this.regionalGuidanceAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer); this.stage.add(adapter.konva.layer);
} }
await adapter.update({ await adapter.update({
@ -333,7 +361,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id
) { ) {
this.log.debug('Rendering inpaint mask'); this.log.debug('Rendering inpaint mask');
await this.inpaintMask.update({ await this.inpaintMaskAdapter.update({
state: state.inpaintMask, state: state.inpaintMask,
toolState: state.tool, toolState: state.tool,
isSelected: state.selectedEntityIdentifier?.id === state.inpaintMask.id, isSelected: state.selectedEntityIdentifier?.id === state.inpaintMask.id,
@ -354,11 +382,6 @@ export class CanvasManager {
await this.preview.bbox.render(); await this.preview.bbox.render();
} }
if (this._isFirstRender || state.layers !== this._prevState.layers || state.regions !== this._prevState.regions) {
// this.log.debug('Updating entity bboxes');
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
}
if (this._isFirstRender || state.session !== this._prevState.session) { if (this._isFirstRender || state.session !== this._prevState.session) {
this.log.debug('Rendering staging area'); this.log.debug('Rendering staging area');
await this.preview.stagingArea.render(); await this.preview.stagingArea.render();
@ -366,7 +389,7 @@ export class CanvasManager {
if ( if (
this._isFirstRender || this._isFirstRender ||
state.layers.entities !== this._prevState.layers.entities || state.rasterLayers.entities !== this._prevState.rasterLayers.entities ||
state.regions.entities !== this._prevState.regions.entities || state.regions.entities !== this._prevState.regions.entities ||
state.inpaintMask !== this._prevState.inpaintMask || state.inpaintMask !== this._prevState.inpaintMask ||
state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id
@ -402,15 +425,15 @@ export class CanvasManager {
return () => { return () => {
this.log.debug('Cleaning up konva renderer'); this.log.debug('Cleaning up konva renderer');
this.inpaintMask.destroy(); this.inpaintMaskAdapter.destroy();
for (const region of this.regions.values()) { for (const adapter of this.regionalGuidanceAdapters.values()) {
region.destroy(); adapter.destroy();
} }
for (const layer of this.layers.values()) { for (const adapter of this.rasterLayerAdapters.values()) {
layer.destroy(); adapter.destroy();
} }
for (const controlAdapter of this.controlAdapters.values()) { for (const adapter of this.controlLayerAdapters.values()) {
controlAdapter.destroy(); adapter.destroy();
} }
this.background.destroy(); this.background.destroy();
this.preview.destroy(); this.preview.destroy();
@ -507,7 +530,7 @@ export class CanvasManager {
} }
getCompositeLayerStageClone = (): Konva.Stage => { getCompositeLayerStageClone = (): Konva.Stage => {
const layersState = this.stateApi.getLayersState(); const layersState = this.stateApi.getRasterLayersState();
const stageClone = this.stage.clone(); const stageClone = this.stage.clone();
stageClone.scaleX(1); stageClone.scaleX(1);
@ -536,7 +559,7 @@ export class CanvasManager {
}; };
getCompositeRasterizedImageCache = (rect: Rect): ImageCache | null => { getCompositeRasterizedImageCache = (rect: Rect): ImageCache | null => {
const layerState = this.stateApi.getLayersState(); const layerState = this.stateApi.getRasterLayersState();
const imageCache = layerState.compositeRasterizationCache.find((cache) => isEqual(cache.rect, rect)); const imageCache = layerState.compositeRasterizationCache.find((cache) => isEqual(cache.rect, rect));
return imageCache ?? null; return imageCache ?? null;
}; };
@ -567,11 +590,11 @@ export class CanvasManager {
}; };
getInpaintMaskBlob = (rect?: Rect): Promise<Blob> => { getInpaintMaskBlob = (rect?: Rect): Promise<Blob> => {
return this.inpaintMask.renderer.getBlob(rect); return this.inpaintMaskAdapter.renderer.getBlob(rect);
}; };
getInpaintMaskImageData = (rect?: Rect): ImageData => { getInpaintMaskImageData = (rect?: Rect): ImageData => {
return this.inpaintMask.renderer.getImageData(rect); return this.inpaintMaskAdapter.renderer.getImageData(rect);
}; };
getGenerationMode(): GenerationMode { getGenerationMode(): GenerationMode {
@ -617,7 +640,7 @@ export class CanvasManager {
logDebugInfo() { logDebugInfo() {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log(this); console.log(this);
for (const layer of this.layers.values()) { for (const layer of this.rasterLayerAdapters.values()) {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.log(layer); console.log(layer);
} }

View File

@ -26,14 +26,15 @@ import {
entityReset, entityReset,
entitySelected, entitySelected,
eraserWidthChanged, eraserWidthChanged,
layerCompositeRasterized, rasterLayerCompositeRasterized,
toolBufferChanged, toolBufferChanged,
toolChanged, toolChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import type { import type {
CanvasControlLayerState,
CanvasEntityIdentifier, CanvasEntityIdentifier,
CanvasInpaintMaskState, CanvasInpaintMaskState,
CanvasLayerState, CanvasRasterLayerState,
CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
CanvasV2State, CanvasV2State,
EntityBrushLineAddedPayload, EntityBrushLineAddedPayload,
@ -53,8 +54,14 @@ import { atom } from 'nanostores';
type EntityStateAndAdapter = type EntityStateAndAdapter =
| { | {
id: string; id: string;
type: CanvasLayerState['type']; type: CanvasRasterLayerState['type'];
state: CanvasLayerState; state: CanvasRasterLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasControlLayerState['type'];
state: CanvasControlLayerState;
adapter: CanvasLayerAdapter; adapter: CanvasLayerAdapter;
} }
| { | {
@ -63,12 +70,6 @@ type EntityStateAndAdapter =
state: CanvasInpaintMaskState; state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter; adapter: CanvasMaskAdapter;
} }
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| { | {
id: string; id: string;
type: CanvasRegionalGuidanceState['type']; type: CanvasRegionalGuidanceState['type'];
@ -117,7 +118,7 @@ export class CanvasStateApi {
}; };
compositeLayerRasterized = (arg: { imageName: string; rect: Rect }) => { compositeLayerRasterized = (arg: { imageName: string; rect: Rect }) => {
log.trace(arg, 'Composite layer rasterized'); log.trace(arg, 'Composite layer rasterized');
this._store.dispatch(layerCompositeRasterized(arg)); this._store.dispatch(rasterLayerCompositeRasterized(arg));
}; };
setSelectedEntity = (arg: EntityIdentifierPayload) => { setSelectedEntity = (arg: EntityIdentifierPayload) => {
log.trace({ arg }, 'Setting selected entity'); log.trace({ arg }, 'Setting selected entity');
@ -157,8 +158,11 @@ export class CanvasStateApi {
getRegionsState = () => { getRegionsState = () => {
return this.getState().regions; return this.getState().regions;
}; };
getLayersState = () => { getRasterLayersState = () => {
return this.getState().layers; return this.getState().rasterLayers;
};
getControlLayersState = () => {
return this.getState().controlLayers;
}; };
getInpaintMaskState = () => { getInpaintMaskState = () => {
return this.getState().inpaintMask; return this.getState().inpaintMask;
@ -185,15 +189,18 @@ export class CanvasStateApi {
let entityState: EntityStateAndAdapter['state'] | null = null; let entityState: EntityStateAndAdapter['state'] | null = null;
let entityAdapter: EntityStateAndAdapter['adapter'] | null = null; let entityAdapter: EntityStateAndAdapter['adapter'] | null = null;
if (identifier.type === 'layer') { if (identifier.type === 'raster_layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null; entityState = state.rasterLayers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.layers.get(identifier.id) ?? null; entityAdapter = this.manager.rasterLayerAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'control_layer') {
entityState = state.controlLayers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.controlLayerAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') { } else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null; entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.regions.get(identifier.id) ?? null; entityAdapter = this.manager.regionalGuidanceAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') { } else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask; entityState = state.inpaintMask;
entityAdapter = this.manager.inpaintMask; entityAdapter = this.manager.inpaintMaskAdapter;
} }
if (entityState && entityAdapter) { if (entityState && entityAdapter) {

View File

@ -8,6 +8,7 @@ import {
BRUSH_ERASER_BORDER_WIDTH, BRUSH_ERASER_BORDER_WIDTH,
} from 'features/controlLayers/konva/constants'; } from 'features/controlLayers/konva/constants';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util'; import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import { isDrawableEntity } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
@ -156,10 +157,7 @@ export class CanvasTool {
const tool = toolState.selected; const tool = toolState.selected;
const isDrawableEntity = const isDrawable = selectedEntity && isDrawableEntity(selectedEntity.state);
selectedEntity?.state.type === 'regional_guidance' ||
selectedEntity?.state.type === 'layer' ||
selectedEntity?.state.type === 'inpaint_mask';
// Update the stage's pointer style // Update the stage's pointer style
if (tool === 'view') { if (tool === 'view') {
@ -168,7 +166,7 @@ export class CanvasTool {
} else if (renderedEntityCount === 0) { } else if (renderedEntityCount === 0) {
// We have no layers, so we should not render any tool // We have no layers, so we should not render any tool
stage.container().style.cursor = 'default'; stage.container().style.cursor = 'default';
} else if (!isDrawableEntity) { } else if (!isDrawable) {
// Non-drawable layers don't have tools // Non-drawable layers don't have tools
stage.container().style.cursor = 'not-allowed'; stage.container().style.cursor = 'not-allowed';
} else if (tool === 'move' || Boolean(this.manager.stateApi.$transformingEntity.get())) { } else if (tool === 'move' || Boolean(this.manager.stateApi.$transformingEntity.get())) {
@ -186,7 +184,7 @@ export class CanvasTool {
stage.draggable(tool === 'view'); stage.draggable(tool === 'view');
if (!cursorPos || renderedEntityCount === 0 || !isDrawableEntity) { if (!cursorPos || renderedEntityCount === 0 || !isDrawable) {
// We can bail early if the mouse isn't over the stage or there are no layers // We can bail early if the mouse isn't over the stage or there are no layers
this.konva.group.visible(false); this.konva.group.visible(false);
} else { } else {

View File

@ -6,8 +6,9 @@ import {
offsetCoord, offsetCoord,
} from 'features/controlLayers/konva/util'; } from 'features/controlLayers/konva/util';
import type { import type {
CanvasControlLayerState,
CanvasInpaintMaskState, CanvasInpaintMaskState,
CanvasLayerState, CanvasRasterLayerState,
CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
CanvasV2State, CanvasV2State,
Coordinate, Coordinate,
@ -84,7 +85,7 @@ const getLastPointOfLine = (points: number[]): Coordinate | null => {
}; };
const getLastPointOfLastLineOfEntity = ( const getLastPointOfLastLineOfEntity = (
entity: CanvasLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState, entity: CanvasRasterLayerState | CanvasControlLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState,
tool: Tool tool: Tool
): Coordinate | null => { ): Coordinate | null => {
const lastObject = entity.objects[entity.objects.length - 1]; const lastObject = entity.objects[entity.objects.length - 1];
@ -138,7 +139,9 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
return e.evt.buttons === 1; return e.evt.buttons === 1;
} }
function getClip(entity: CanvasRegionalGuidanceState | CanvasLayerState | CanvasInpaintMaskState) { function getClip(
entity: CanvasRegionalGuidanceState | CanvasControlLayerState | CanvasRasterLayerState | CanvasInpaintMaskState
) {
const settings = getSettings(); const settings = getSettings();
const bboxRect = getBbox().rect; const bboxRect = getBbox().rect;

View File

@ -12,12 +12,12 @@ import { pick } from 'lodash-es';
export const bboxReducers = { export const bboxReducers = {
bboxScaledSizeChanged: (state, action: PayloadAction<Partial<Dimensions>>) => { bboxScaledSizeChanged: (state, action: PayloadAction<Partial<Dimensions>>) => {
state.layers.imageCache = null; state.rasterLayers.imageCache = null;
state.bbox.scaledSize = { ...state.bbox.scaledSize, ...action.payload }; state.bbox.scaledSize = { ...state.bbox.scaledSize, ...action.payload };
}, },
bboxScaleMethodChanged: (state, action: PayloadAction<BoundingBoxScaleMethod>) => { bboxScaleMethodChanged: (state, action: PayloadAction<BoundingBoxScaleMethod>) => {
state.bbox.scaleMethod = action.payload; state.bbox.scaleMethod = action.payload;
state.layers.imageCache = null; state.rasterLayers.imageCache = null;
if (action.payload === 'auto') { if (action.payload === 'auto') {
const optimalDimension = getOptimalDimension(state.params.model); const optimalDimension = getOptimalDimension(state.params.model);
@ -27,7 +27,7 @@ export const bboxReducers = {
}, },
bboxChanged: (state, action: PayloadAction<IRect>) => { bboxChanged: (state, action: PayloadAction<IRect>) => {
state.bbox.rect = action.payload; state.bbox.rect = action.payload;
state.layers.imageCache = null; state.rasterLayers.imageCache = null;
if (state.bbox.scaleMethod === 'auto') { if (state.bbox.scaleMethod === 'auto') {
const optimalDimension = getOptimalDimension(state.params.model); const optimalDimension = getOptimalDimension(state.params.model);

View File

@ -5,11 +5,12 @@ import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/uti
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { bboxReducers } from 'features/controlLayers/store/bboxReducers'; import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
import { compositingReducers } from 'features/controlLayers/store/compositingReducers'; import { compositingReducers } from 'features/controlLayers/store/compositingReducers';
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers'; import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers';
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers'; import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
import { layersReducers } from 'features/controlLayers/store/layersReducers';
import { lorasReducers } from 'features/controlLayers/store/lorasReducers'; import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { paramsReducers } from 'features/controlLayers/store/paramsReducers'; import { paramsReducers } from 'features/controlLayers/store/paramsReducers';
import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers'; import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
import { sessionReducers } from 'features/controlLayers/store/sessionReducers'; import { sessionReducers } from 'features/controlLayers/store/sessionReducers';
import { settingsReducers } from 'features/controlLayers/store/settingsReducers'; import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
@ -23,9 +24,10 @@ import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import type { import type {
CanvasControlLayerState,
CanvasEntityIdentifier, CanvasEntityIdentifier,
CanvasInpaintMaskState, CanvasInpaintMaskState,
CanvasLayerState, CanvasRasterLayerState,
CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
CanvasV2State, CanvasV2State,
Coordinate, Coordinate,
@ -38,12 +40,13 @@ import type {
FilterConfig, FilterConfig,
StageAttrs, StageAttrs,
} from './types'; } from './types';
import { IMAGE_FILTERS, RGBA_RED } from './types'; import { IMAGE_FILTERS, isDrawableEntity, RGBA_RED } from './types';
const initialState: CanvasV2State = { const initialState: CanvasV2State = {
_version: 3, _version: 3,
selectedEntityIdentifier: null, selectedEntityIdentifier: null,
layers: { entities: [], compositeRasterizationCache: [] }, rasterLayers: { entities: [], compositeRasterizationCache: [] },
controlLayers: { entities: [] },
ipAdapters: { entities: [] }, ipAdapters: { entities: [] },
regions: { entities: [] }, regions: { entities: [] },
loras: [], loras: [],
@ -143,27 +146,21 @@ const initialState: CanvasV2State = {
export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) { export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) {
switch (type) { switch (type) {
case 'layer': case 'raster_layer':
return state.layers.entities.find((layer) => layer.id === id); return state.rasterLayers.entities.find((layer) => layer.id === id);
case 'control_layer':
return state.controlLayers.entities.find((layer) => layer.id === id);
case 'inpaint_mask': case 'inpaint_mask':
return state.inpaintMask; return state.inpaintMask;
case 'regional_guidance': case 'regional_guidance':
return state.regions.entities.find((rg) => rg.id === id); return state.regions.entities.find((rg) => rg.id === id);
case 'ip_adapter':
return state.ipAdapters.entities.find((ip) => ip.id === id);
default: default:
return; return;
} }
} }
const invalidateCompositeRasterizationCache = (entity: CanvasLayerState, state: CanvasV2State) => {
if (entity.controlAdapter === null) {
state.layers.compositeRasterizationCache = [];
}
};
const invalidateRasterizationCaches = ( const invalidateRasterizationCaches = (
entity: CanvasLayerState | CanvasInpaintMaskState | CanvasRegionalGuidanceState, entity: CanvasRasterLayerState | CanvasControlLayerState | CanvasInpaintMaskState | CanvasRegionalGuidanceState,
state: CanvasV2State state: CanvasV2State
) => { ) => {
// TODO(psyche): We can be more efficient and only invalidate caches when the entity's changes intersect with the // TODO(psyche): We can be more efficient and only invalidate caches when the entity's changes intersect with the
@ -176,8 +173,8 @@ const invalidateRasterizationCaches = (
// layer's image data will contribute to the composite layer's image data. // layer's image data will contribute to the composite layer's image data.
// If the layer is used as a control layer, it will not contribute to the composite layer, so we do not need to reset // If the layer is used as a control layer, it will not contribute to the composite layer, so we do not need to reset
// its cache. // its cache.
if (entity.type === 'layer') { if (entity.type === 'raster_layer') {
invalidateCompositeRasterizationCache(entity, state); state.rasterLayers.compositeRasterizationCache = [];
} }
}; };
@ -185,7 +182,8 @@ export const canvasV2Slice = createSlice({
name: 'canvasV2', name: 'canvasV2',
initialState, initialState,
reducers: { reducers: {
...layersReducers, ...rasterLayersReducers,
...controlLayersReducers,
...ipAdaptersReducers, ...ipAdaptersReducers,
...regionsReducers, ...regionsReducers,
...lorasReducers, ...lorasReducers,
@ -205,7 +203,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier); const entity = selectEntity(state, entityIdentifier);
if (!entity) { if (!entity) {
return; return;
} else if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { } else if (isDrawableEntity(entity)) {
entity.isEnabled = true; entity.isEnabled = true;
entity.objects = []; entity.objects = [];
entity.position = { x: 0, y: 0 }; entity.position = { x: 0, y: 0 };
@ -229,7 +227,7 @@ export const canvasV2Slice = createSlice({
return; return;
} }
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { if (isDrawableEntity(entity)) {
entity.position = position; entity.position = position;
// When an entity is moved, we need to invalidate the rasterization caches. // When an entity is moved, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state); invalidateRasterizationCaches(entity, state);
@ -242,7 +240,7 @@ export const canvasV2Slice = createSlice({
return; return;
} }
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { if (isDrawableEntity(entity)) {
entity.objects = [imageObject]; entity.objects = [imageObject];
entity.position = { x: rect.x, y: rect.y }; entity.position = { x: rect.x, y: rect.y };
// Remove the cache for the given rect. This should never happen, because we should never rasterize the same // Remove the cache for the given rect. This should never happen, because we should never rasterize the same
@ -258,7 +256,7 @@ export const canvasV2Slice = createSlice({
return; return;
} }
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { if (isDrawableEntity(entity)) {
entity.objects.push(brushLine); entity.objects.push(brushLine);
// When adding a brush line, we need to invalidate the rasterization caches. // When adding a brush line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state); invalidateRasterizationCaches(entity, state);
@ -269,7 +267,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier); const entity = selectEntity(state, entityIdentifier);
if (!entity) { if (!entity) {
return; return;
} else if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { } else if (isDrawableEntity(entity)) {
entity.objects.push(eraserLine); entity.objects.push(eraserLine);
// When adding an eraser line, we need to invalidate the rasterization caches. // When adding an eraser line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state); invalidateRasterizationCaches(entity, state);
@ -282,7 +280,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier); const entity = selectEntity(state, entityIdentifier);
if (!entity) { if (!entity) {
return; return;
} else if (entity.type === 'layer') { } else if (isDrawableEntity(entity)) {
entity.objects.push(rect); entity.objects.push(rect);
// When adding an eraser line, we need to invalidate the rasterization caches. // When adding an eraser line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state); invalidateRasterizationCaches(entity, state);
@ -292,18 +290,37 @@ export const canvasV2Slice = createSlice({
}, },
entityDeleted: (state, action: PayloadAction<EntityIdentifierPayload>) => { entityDeleted: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload; const { entityIdentifier } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (entity?.type === 'layer') { let selectedEntityIdentifier: CanvasEntityIdentifier = { type: state.inpaintMask.type, id: state.inpaintMask.id };
// When a layer is deleted, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state); if (entityIdentifier.type === 'raster_layer') {
// When deleting a raster layer, we need to invalidate the composite rasterization cache.
const index = state.rasterLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
state.rasterLayers.compositeRasterizationCache = [];
const nextRasterLayer = state.rasterLayers.entities[index];
if (nextRasterLayer) {
selectedEntityIdentifier = { type: nextRasterLayer.type, id: nextRasterLayer.id };
}
} else if (entityIdentifier.type === 'control_layer') {
const index = state.controlLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id);
const nextControlLayer = state.controlLayers.entities[index];
if (nextControlLayer) {
selectedEntityIdentifier = { type: nextControlLayer.type, id: nextControlLayer.id };
} }
if (entityIdentifier.type === 'layer') {
state.layers.entities = state.layers.entities.filter((layer) => layer.id !== entityIdentifier.id);
} else if (entityIdentifier.type === 'regional_guidance') { } else if (entityIdentifier.type === 'regional_guidance') {
const index = state.regions.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id); state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id);
const region = state.regions.entities[index];
if (region) {
selectedEntityIdentifier = { type: region.type, id: region.id };
}
} else { } else {
assert(false, 'Not implemented'); assert(false, 'Not implemented');
} }
state.selectedEntityIdentifier = selectedEntityIdentifier;
}, },
entityArrangedForwardOne: (state, action: PayloadAction<EntityIdentifierPayload>) => { entityArrangedForwardOne: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload; const { entityIdentifier } = action.payload;
@ -311,10 +328,12 @@ export const canvasV2Slice = createSlice({
if (!entity) { if (!entity) {
return; return;
} }
if (entity.type === 'layer') { if (entity.type === 'raster_layer') {
moveOneToEnd(state.layers.entities, entity); moveOneToEnd(state.rasterLayers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache. // When arranging a raster layer, we need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state); state.rasterLayers.compositeRasterizationCache = [];
} else if (entity.type === 'control_layer') {
moveOneToEnd(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') { } else if (entity.type === 'regional_guidance') {
moveOneToEnd(state.regions.entities, entity); moveOneToEnd(state.regions.entities, entity);
} }
@ -325,10 +344,12 @@ export const canvasV2Slice = createSlice({
if (!entity) { if (!entity) {
return; return;
} }
if (entity.type === 'layer') { if (entity.type === 'raster_layer') {
moveToEnd(state.layers.entities, entity); moveToEnd(state.rasterLayers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache. // When arranging a raster layer, we need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state); state.rasterLayers.compositeRasterizationCache = [];
} else if (entity.type === 'control_layer') {
moveToEnd(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') { } else if (entity.type === 'regional_guidance') {
moveToEnd(state.regions.entities, entity); moveToEnd(state.regions.entities, entity);
} }
@ -339,10 +360,11 @@ export const canvasV2Slice = createSlice({
if (!entity) { if (!entity) {
return; return;
} }
if (entity.type === 'layer') { if (entity.type === 'raster_layer') {
moveOneToStart(state.layers.entities, entity); moveOneToStart(state.rasterLayers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache. // When arranging a raster layer, we need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state); } else if (entity.type === 'control_layer') {
moveOneToStart(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') { } else if (entity.type === 'regional_guidance') {
moveOneToStart(state.regions.entities, entity); moveOneToStart(state.regions.entities, entity);
} }
@ -353,18 +375,19 @@ export const canvasV2Slice = createSlice({
if (!entity) { if (!entity) {
return; return;
} }
if (entity.type === 'layer') { if (entity.type === 'raster_layer') {
moveToStart(state.layers.entities, entity); moveToStart(state.rasterLayers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache. state.rasterLayers.compositeRasterizationCache = [];
invalidateCompositeRasterizationCache(entity, state); } else if (entity.type === 'control_layer') {
moveToStart(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') { } else if (entity.type === 'regional_guidance') {
moveToStart(state.regions.entities, entity); moveToStart(state.regions.entities, entity);
} }
}, },
allEntitiesDeleted: (state) => { allEntitiesDeleted: (state) => {
state.regions.entities = []; state.regions.entities = [];
state.layers.entities = []; state.rasterLayers.entities = [];
state.layers.compositeRasterizationCache = []; state.rasterLayers.compositeRasterizationCache = [];
state.ipAdapters.entities = []; state.ipAdapters.entities = [];
}, },
filterSelected: (state, action: PayloadAction<{ type: FilterConfig['type'] }>) => { filterSelected: (state, action: PayloadAction<{ type: FilterConfig['type'] }>) => {
@ -377,8 +400,8 @@ export const canvasV2Slice = createSlice({
// Invalidate the rasterization caches for all entities. // Invalidate the rasterization caches for all entities.
// Layers & composite layer // Layers & composite layer
state.layers.compositeRasterizationCache = []; state.rasterLayers.compositeRasterizationCache = [];
for (const layer of state.layers.entities) { for (const layer of state.rasterLayers.entities) {
layer.rasterizationCache = []; layer.rasterizationCache = [];
} }
@ -399,7 +422,8 @@ export const canvasV2Slice = createSlice({
state.bbox.scaledSize = getScaledBoundingBoxDimensions(size, optimalDimension); state.bbox.scaledSize = getScaledBoundingBoxDimensions(size, optimalDimension);
state.ipAdapters = deepClone(initialState.ipAdapters); state.ipAdapters = deepClone(initialState.ipAdapters);
state.layers = deepClone(initialState.layers); state.rasterLayers = deepClone(initialState.rasterLayers);
state.controlLayers = deepClone(initialState.controlLayers);
state.regions = deepClone(initialState.regions); state.regions = deepClone(initialState.regions);
state.selectedEntityIdentifier = deepClone(initialState.selectedEntityIdentifier); state.selectedEntityIdentifier = deepClone(initialState.selectedEntityIdentifier);
state.session = deepClone(initialState.session); state.session = deepClone(initialState.session);
@ -445,16 +469,21 @@ export const {
bboxAspectRatioIdChanged, bboxAspectRatioIdChanged,
bboxDimensionsSwapped, bboxDimensionsSwapped,
bboxSizeOptimized, bboxSizeOptimized,
// layers // Raster layers
layerAdded, rasterLayerAdded,
layerRecalled, rasterLayerRecalled,
layerAllDeleted, rasterLayerAllDeleted,
layerUsedAsControlChanged, rasterLayerConvertedToControlLayer,
layerControlAdapterModelChanged, rasterLayerCompositeRasterized,
layerControlAdapterControlModeChanged, // Control layers
layerControlAdapterWeightChanged, controlLayerAdded,
layerControlAdapterBeginEndStepPctChanged, controlLayerRecalled,
layerCompositeRasterized, controlLayerAllDeleted,
controlLayerConvertedToRasterLayer,
controlLayerModelChanged,
controlLayerControlModeChanged,
controlLayerWeightChanged,
controlLayerBeginEndStepPctChanged,
// IP Adapters // IP Adapters
ipaAdded, ipaAdded,
ipaRecalled, ipaRecalled,

View File

@ -0,0 +1,153 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
CanvasControlLayerState,
CanvasRasterLayerState,
CanvasV2State,
ControlModeV2,
ControlNetConfig,
T2IAdapterConfig,
} from './types';
import { initialControlNetV2 } from './types';
export const selectControlLayer = (state: CanvasV2State, id: string) =>
state.controlLayers.entities.find((layer) => layer.id === id);
export const selectControlLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectControlLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const controlLayersReducers = {
controlLayerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasControlLayerState = {
id,
type: 'control_layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
controlAdapter: deepClone(initialControlNetV2),
};
merge(layer, overrides);
state.controlLayers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'control_layer', id };
}
},
prepare: (payload: { overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('control_layer') },
}),
},
controlLayerRecalled: (state, action: PayloadAction<{ data: CanvasControlLayerState }>) => {
const { data } = action.payload;
state.controlLayers.entities.push(data);
state.selectedEntityIdentifier = { type: 'control_layer', id: data.id };
},
controlLayerAllDeleted: (state) => {
state.controlLayers.entities = [];
},
controlLayerConvertedToRasterLayer: {
reducer: (state, action: PayloadAction<{ id: string; newId: string }>) => {
const { id, newId } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer) {
return;
}
// Convert the raster layer to control layer
const rasterLayerState: CanvasRasterLayerState = {
...omit(deepClone(layer), ['type', 'controlAdapter']),
id: newId,
type: 'raster_layer',
};
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== id);
// Add the new raster layer
state.rasterLayers.entities.push(rasterLayerState);
// The composite layer's image data will change when the control layer is converted to raster layer.
state.rasterLayers.compositeRasterizationCache = [];
state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id };
},
prepare: (payload: { id: string }) => ({
payload: { ...payload, newId: getPrefixedId('raster_layer') },
}),
},
controlLayerModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
controlLayerControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
controlLayerWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
controlLayerBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,142 +0,0 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { isEqual, merge } from 'lodash-es';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { CanvasLayerState, CanvasV2State, ControlModeV2, ControlNetConfig, Rect, T2IAdapterConfig } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.entities.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const layersReducers = {
layerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasLayerState = {
id,
type: 'layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
controlAdapter: null,
};
merge(layer, overrides);
state.layers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'layer', id };
}
if (layer.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.layers.compositeRasterizationCache = [];
}
},
prepare: (payload: { overrides?: Partial<CanvasLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('layer') },
}),
},
layerRecalled: (state, action: PayloadAction<{ data: CanvasLayerState }>) => {
const { data } = action.payload;
state.layers.entities.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
if (data.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.layers.compositeRasterizationCache = [];
}
},
layerAllDeleted: (state) => {
state.layers.entities = [];
state.layers.compositeRasterizationCache = [];
},
layerCompositeRasterized: (state, action: PayloadAction<{ imageName: string; rect: Rect }>) => {
state.layers.compositeRasterizationCache = state.layers.compositeRasterizationCache.filter(
(cache) => !isEqual(cache.rect, action.payload.rect)
);
state.layers.compositeRasterizationCache.push(action.payload);
},
layerUsedAsControlChanged: (
state,
action: PayloadAction<{ id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig | null }>
) => {
const { id, controlAdapter } = action.payload;
const layer = selectLayer(state, id);
if (!layer) {
return;
}
layer.controlAdapter = controlAdapter;
// The composite layer's image data will change when the layer is used as control (or not). Invalidate the cache.
state.layers.compositeRasterizationCache = [];
},
layerControlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
layerControlAdapterControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
layerControlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
layerControlAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -0,0 +1,108 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { isEqual, merge } from 'lodash-es';
import { assert } from 'tsafe';
import type {
CanvasControlLayerState,
CanvasRasterLayerState,
CanvasV2State,
ControlNetConfig,
Rect,
T2IAdapterConfig,
} from './types';
export const selectRasterLayer = (state: CanvasV2State, id: string) =>
state.rasterLayers.entities.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectRasterLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const rasterLayersReducers = {
rasterLayerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasRasterLayerState = {
id,
type: 'raster_layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
};
merge(layer, overrides);
state.rasterLayers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'raster_layer', id };
}
if (layer.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.rasterLayers.compositeRasterizationCache = [];
}
},
prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('raster_layer') },
}),
},
rasterLayerRecalled: (state, action: PayloadAction<{ data: CanvasRasterLayerState }>) => {
const { data } = action.payload;
state.rasterLayers.entities.push(data);
state.selectedEntityIdentifier = { type: 'raster_layer', id: data.id };
if (data.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.rasterLayers.compositeRasterizationCache = [];
}
},
rasterLayerAllDeleted: (state) => {
state.rasterLayers.entities = [];
state.rasterLayers.compositeRasterizationCache = [];
},
rasterLayerCompositeRasterized: (state, action: PayloadAction<{ imageName: string; rect: Rect }>) => {
state.rasterLayers.compositeRasterizationCache = state.rasterLayers.compositeRasterizationCache.filter(
(cache) => !isEqual(cache.rect, action.payload.rect)
);
state.rasterLayers.compositeRasterizationCache.push(action.payload);
},
rasterLayerConvertedToControlLayer: {
reducer: (
state,
action: PayloadAction<{ id: string; newId: string; controlAdapter: ControlNetConfig | T2IAdapterConfig }>
) => {
const { id, newId, controlAdapter } = action.payload;
const layer = selectRasterLayer(state, id);
if (!layer) {
return;
}
// Convert the raster layer to control layer
const controlLayerState: CanvasControlLayerState = {
...deepClone(layer),
id: newId,
type: 'control_layer',
controlAdapter,
};
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== id);
// Add the converted control layer
state.controlLayers.entities.push(controlLayerState);
// The composite layer's image data will change when the raster layer is converted to control layer.
state.rasterLayers.compositeRasterizationCache = [];
state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id };
},
prepare: (payload: { id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig }) => ({
payload: { ...payload, newId: getPrefixedId('control_layer') },
}),
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -7,7 +7,7 @@ export const selectEntityCount = createSelector(selectCanvasV2Slice, (canvasV2)
canvasV2.regions.entities.length + canvasV2.regions.entities.length +
// canvasV2.controlAdapters.entities.length + // canvasV2.controlAdapters.entities.length +
canvasV2.ipAdapters.entities.length + canvasV2.ipAdapters.entities.length +
canvasV2.layers.entities.length canvasV2.rasterLayers.entities.length
); );
}); });

View File

@ -728,23 +728,22 @@ const zT2IAdapterConfig = z.object({
}); });
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>; export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
export const zCanvasLayerState = z.object({ export const zCanvasRasterLayerState = z.object({
id: zId, id: zId,
type: z.literal('layer'), type: z.literal('raster_layer'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
position: zCoordinate, position: zCoordinate,
opacity: zOpacity, opacity: zOpacity,
objects: z.array(zCanvasObjectState), objects: z.array(zCanvasObjectState),
rasterizationCache: z.array(zImageCache), rasterizationCache: z.array(zImageCache),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]).nullable(),
}); });
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>; export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
export type CanvasLayerStateWithValidControlNet = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<ControlNetConfig, 'model'> & { model: ControlNetModelConfig }; export const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
}; type: z.literal('control_layer'),
export type CanvasLayerStateWithValidT2IAdapter = Omit<CanvasLayerState, 'controlAdapter'> & { controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
controlAdapter: Omit<T2IAdapterConfig, 'model'> & { model: T2IAdapterModelConfig }; });
}; export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;
export const initialControlNetV2: ControlNetConfig = { export const initialControlNetV2: ControlNetConfig = {
type: 'controlnet', type: 'controlnet',
@ -808,8 +807,8 @@ export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMetho
zBoundingBoxScaleMethod.safeParse(v).success; zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntityState = export type CanvasEntityState =
| CanvasLayerState | CanvasRasterLayerState
| CanvasControlAdapterState | CanvasControlLayerState
| CanvasRegionalGuidanceState | CanvasRegionalGuidanceState
| CanvasInpaintMaskState | CanvasInpaintMaskState
| CanvasIPAdapterState; | CanvasIPAdapterState;
@ -832,7 +831,8 @@ export type CanvasV2State = {
_version: 3; _version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null; selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: CanvasInpaintMaskState; inpaintMask: CanvasInpaintMaskState;
layers: { entities: CanvasLayerState[]; compositeRasterizationCache: ImageCache[] }; rasterLayers: { entities: CanvasRasterLayerState[]; compositeRasterizationCache: ImageCache[] };
controlLayers: { entities: CanvasControlLayerState[] };
ipAdapters: { entities: CanvasIPAdapterState[] }; ipAdapters: { entities: CanvasIPAdapterState[] };
regions: { entities: CanvasRegionalGuidanceState[] }; regions: { entities: CanvasRegionalGuidanceState[] };
loras: LoRA[]; loras: LoRA[];
@ -962,10 +962,19 @@ export type RemoveIndexString<T> = {
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'; export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
export function isDrawableEntityType(entityType: CanvasEntityState['type']) {
return (
entityType === 'raster_layer' ||
entityType === 'control_layer' ||
entityType === 'regional_guidance' ||
entityType === 'inpaint_mask'
);
}
export function isDrawableEntity( export function isDrawableEntity(
entity: CanvasEntityState entity: CanvasEntityState
): entity is CanvasLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState { ): entity is CanvasRasterLayerState | CanvasControlLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState {
return entity.type === 'layer' || entity.type === 'regional_guidance' || entity.type === 'inpaint_mask'; return isDrawableEntityType(entity.type);
} }
export function isDrawableEntityAdapter( export function isDrawableEntityAdapter(
@ -973,9 +982,3 @@ export function isDrawableEntityAdapter(
): adapter is CanvasLayerAdapter | CanvasMaskAdapter { ): adapter is CanvasLayerAdapter | CanvasMaskAdapter {
return adapter instanceof CanvasLayerAdapter || adapter instanceof CanvasMaskAdapter; return adapter instanceof CanvasLayerAdapter || adapter instanceof CanvasMaskAdapter;
} }
export function isDrawableEntityType(
entityType: CanvasEntityState['type']
): entityType is 'layer' | 'regional_guidance' | 'inpaint_mask' {
return entityType === 'layer' || entityType === 'regional_guidance' || entityType === 'inpaint_mask';
}

View File

@ -11,7 +11,7 @@ import { some } from 'lodash-es';
import type { ImageUsage } from './types'; import type { ImageUsage } from './types';
export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_name: string) => { export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_name: string) => {
const isLayerImage = canvasV2.layers.entities.some((layer) => const isLayerImage = canvasV2.rasterLayers.entities.some((layer) =>
layer.objects.some((obj) => obj.type === 'image' && obj.image.image_name === image_name) layer.objects.some((obj) => obj.type === 'image' && obj.image.image_name === image_name)
); );

View File

@ -1,4 +1,4 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types'; import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers'; import { handlers } from 'features/metadata/util/handlers';
@ -9,7 +9,7 @@ type Props = {
}; };
export const MetadataLayers = ({ metadata }: Props) => { export const MetadataLayers = ({ metadata }: Props) => {
const [layers, setLayers] = useState<CanvasLayerState[]>([]); const [layers, setLayers] = useState<CanvasRasterLayerState[]>([]);
useEffect(() => { useEffect(() => {
const parse = async () => { const parse = async () => {
@ -40,8 +40,8 @@ const MetadataViewLayer = ({
handlers, handlers,
}: { }: {
label: string; label: string;
layer: CanvasLayerState; layer: CanvasRasterLayerState;
handlers: MetadataHandlers<CanvasLayerState[], CanvasLayerState>; handlers: MetadataHandlers<CanvasRasterLayerState[], CanvasRasterLayerState>;
}) => { }) => {
const onRecall = useCallback(() => { const onRecall = useCallback(() => {
if (!handlers.recallItem) { if (!handlers.recallItem) {

View File

@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys'; import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/canvasV2Slice'; import { shouldConcatPromptsChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState, LoRA } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import type { import type {
AnyControlAdapterConfigMetadata, AnyControlAdapterConfigMetadata,
BuildMetadataHandlers, BuildMetadataHandlers,
@ -48,7 +48,7 @@ const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfig
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`; return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
} }
}; };
const renderLayerValue: MetadataRenderValueFunc<CanvasLayerState> = async (layer) => { const renderLayerValue: MetadataRenderValueFunc<CanvasRasterLayerState> = async (layer) => {
if (layer.type === 'initial_image_layer') { if (layer.type === 'initial_image_layer') {
let rendered = t('controlLayers.globalInitialImageLayer'); let rendered = t('controlLayers.globalInitialImageLayer');
if (layer.image) { if (layer.image) {
@ -88,7 +88,7 @@ const renderLayerValue: MetadataRenderValueFunc<CanvasLayerState> = async (layer
} }
assert(false, 'Unknown layer type'); assert(false, 'Unknown layer type');
}; };
const renderLayersValue: MetadataRenderValueFunc<CanvasLayerState[]> = async (layers) => { const renderLayersValue: MetadataRenderValueFunc<CanvasRasterLayerState[]> = async (layers) => {
return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`; return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`;
}; };

View File

@ -1,6 +1,6 @@
import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/controlLayers/konva/naming'; import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/controlLayers/konva/naming';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers'; import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasLayerState, LoRA } from 'features/controlLayers/store/types'; import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import { import {
IMAGE_FILTERS, IMAGE_FILTERS,
imageDTOToImageWithDims, imageDTOToImageWithDims,
@ -8,7 +8,7 @@ import {
initialIPAdapterV2, initialIPAdapterV2,
initialT2IAdapterV2, initialT2IAdapterV2,
isFilterType, isFilterType,
zCanvasLayerState, zCanvasRasterLayerState,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
@ -424,22 +424,22 @@ const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (
}; };
//#region Control Layers //#region Control Layers
const parseLayer: MetadataParseFunc<CanvasLayerState> = async (metadataItem) => zCanvasLayerState.parseAsync(metadataItem); const parseLayer: MetadataParseFunc<CanvasRasterLayerState> = async (metadataItem) => zCanvasRasterLayerState.parseAsync(metadataItem);
const parseLayers: MetadataParseFunc<CanvasLayerState[]> = async (metadata) => { const parseLayers: MetadataParseFunc<CanvasRasterLayerState[]> = async (metadata) => {
// We need to support recalling pre-Control Layers metadata into Control Layers. A separate set of parsers handles // We need to support recalling pre-Control Layers metadata into Control Layers. A separate set of parsers handles
// taking pre-CL metadata and parsing it into layers. It doesn't always map 1-to-1, so this is best-effort. For // taking pre-CL metadata and parsing it into layers. It doesn't always map 1-to-1, so this is best-effort. For
// example, CL Control Adapters don't support resize mode, so we simply omit that property. // example, CL Control Adapters don't support resize mode, so we simply omit that property.
try { try {
const layers: CanvasLayerState[] = []; const layers: CanvasRasterLayerState[] = [];
try { try {
const control_layers = await getProperty(metadata, 'control_layers'); const control_layers = await getProperty(metadata, 'control_layers');
const controlLayersRaw = await getProperty(control_layers, 'layers', isArray); const controlLayersRaw = await getProperty(control_layers, 'layers', isArray);
const controlLayersParseResults = await Promise.allSettled(controlLayersRaw.map(parseLayer)); const controlLayersParseResults = await Promise.allSettled(controlLayersRaw.map(parseLayer));
const controlLayers = controlLayersParseResults const controlLayers = controlLayersParseResults
.filter((result): result is PromiseFulfilledResult<CanvasLayerState> => result.status === 'fulfilled') .filter((result): result is PromiseFulfilledResult<CanvasRasterLayerState> => result.status === 'fulfilled')
.map((result) => result.value); .map((result) => result.value);
layers.push(...controlLayers); layers.push(...controlLayers);
} catch { } catch {
@ -498,16 +498,16 @@ const parseLayers: MetadataParseFunc<CanvasLayerState[]> = async (metadata) => {
} }
}; };
const parseInitialImageToInitialImageLayer: MetadataParseFunc<CanvasLayerState> = async (metadata) => { const parseInitialImageToInitialImageLayer: MetadataParseFunc<CanvasRasterLayerState> = async (metadata) => {
// TODO(psyche): recall denoise strength // TODO(psyche): recall denoise strength
// const denoisingStrength = await getProperty(metadata, 'strength', isParameterStrength); // const denoisingStrength = await getProperty(metadata, 'strength', isParameterStrength);
const imageName = await getProperty(metadata, 'init_image', isString); const imageName = await getProperty(metadata, 'init_image', isString);
const imageDTO = await getImageDTO(imageName); const imageDTO = await getImageDTO(imageName);
assert(imageDTO, 'ImageDTO is null'); assert(imageDTO, 'ImageDTO is null');
const id = getLayerId(uuidv4()); const id = getLayerId(uuidv4());
const layer: CanvasLayerState = { const layer: CanvasRasterLayerState = {
id, id,
type: 'layer', type: 'raster_layer',
bbox: null, bbox: null,
bboxNeedsUpdate: true, bboxNeedsUpdate: true,
x: 0, x: 0,

View File

@ -15,8 +15,8 @@ import {
bboxWidthChanged, bboxWidthChanged,
// caRecalled, // caRecalled,
ipaRecalled, ipaRecalled,
layerAllDeleted, rasterLayerAllDeleted,
layerRecalled, rasterLayerRecalled,
loraAllDeleted, loraAllDeleted,
loraRecalled, loraRecalled,
negativePrompt2Changed, negativePrompt2Changed,
@ -42,7 +42,7 @@ import {
import type { import type {
CanvasControlAdapterState, CanvasControlAdapterState,
CanvasIPAdapterState, CanvasIPAdapterState,
CanvasLayerState, CanvasRasterLayerState,
CanvasRegionalGuidanceState, CanvasRegionalGuidanceState,
LoRA, LoRA,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
@ -328,7 +328,7 @@ const recallRG: MetadataRecallFunc<CanvasRegionalGuidanceState> = async (rg) =>
}; };
//#region Control Layers //#region Control Layers
const recallLayer: MetadataRecallFunc<CanvasLayerState> = async (layer) => { const recallLayer: MetadataRecallFunc<CanvasRasterLayerState> = async (layer) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(layer); const clone = deepClone(layer);
const invalidObjects: string[] = []; const invalidObjects: string[] = [];
@ -355,13 +355,13 @@ const recallLayer: MetadataRecallFunc<CanvasLayerState> = async (layer) => {
} }
} }
clone.id = getRGId(uuidv4()); clone.id = getRGId(uuidv4());
dispatch(layerRecalled({ data: clone })); dispatch(rasterLayerRecalled({ data: clone }));
return; return;
}; };
const recallLayers: MetadataRecallFunc<CanvasLayerState[]> = (layers) => { const recallLayers: MetadataRecallFunc<CanvasRasterLayerState[]> = (layers) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(layerAllDeleted()); dispatch(rasterLayerAllDeleted());
for (const l of layers) { for (const l of layers) {
recallLayer(l); recallLayer(l);
} }

View File

@ -1,5 +1,5 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { CanvasLayerState, LoRA } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import type { import type {
ControlNetConfigMetadata, ControlNetConfigMetadata,
IPAdapterConfigMetadata, IPAdapterConfigMetadata,
@ -109,7 +109,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
return new Promise((resolve) => resolve(validatedIPAdapters)); return new Promise((resolve) => resolve(validatedIPAdapters));
}; };
const validateLayer: MetadataValidateFunc<CanvasLayerState> = async (layer) => { const validateLayer: MetadataValidateFunc<CanvasRasterLayerState> = async (layer) => {
if (layer.type === 'control_adapter_layer') { if (layer.type === 'control_adapter_layer') {
const model = layer.controlAdapter.model; const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model'); assert(model, 'Control Adapter layer missing model');
@ -131,8 +131,8 @@ const validateLayer: MetadataValidateFunc<CanvasLayerState> = async (layer) => {
return layer; return layer;
}; };
const validateLayers: MetadataValidateFunc<CanvasLayerState[]> = async (layers) => { const validateLayers: MetadataValidateFunc<CanvasRasterLayerState[]> = async (layers) => {
const validatedLayers: CanvasLayerState[] = []; const validatedLayers: CanvasRasterLayerState[] = [];
for (const l of layers) { for (const l of layers) {
try { try {
const validated = await validateLayer(l); const validated = await validateLayer(l);

View File

@ -1,15 +1,10 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { import type {
CanvasLayerState, CanvasControlLayerState,
CanvasLayerStateWithValidControlNet,
CanvasLayerStateWithValidT2IAdapter,
ControlNetConfig, ControlNetConfig,
FilterConfig,
ImageWithDims,
Rect, Rect,
T2IAdapterConfig, T2IAdapterConfig,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
@ -17,18 +12,18 @@ import { assert } from 'tsafe';
export const addControlAdapters = async ( export const addControlAdapters = async (
manager: CanvasManager, manager: CanvasManager,
layers: CanvasLayerState[], layers: CanvasControlLayerState[],
g: Graph, g: Graph,
bbox: Rect, bbox: Rect,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
base: BaseModelType base: BaseModelType
): Promise<(CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter)[]> => { ): Promise<CanvasControlLayerState[]> => {
const layersWithValidControlAdapters = layers const validControlLayers = layers
.filter((layer) => layer.isEnabled) .filter((layer) => layer.isEnabled)
.filter((layer) => doesLayerHaveValidControlAdapter(layer, base)); .filter((layer) => isValidControlAdapter(layer.controlAdapter, base));
for (const layer of layersWithValidControlAdapters) { for (const layer of validControlLayers) {
const adapter = manager.layers.get(layer.id); const adapter = manager.controlLayerAdapters.get(layer.id);
assert(adapter, 'Adapter not found'); assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize(bbox); const imageDTO = await adapter.renderer.rasterize(bbox);
if (layer.controlAdapter.type === 'controlnet') { if (layer.controlAdapter.type === 'controlnet') {
@ -37,7 +32,7 @@ export const addControlAdapters = async (
await addT2IAdapterToGraph(g, layer, imageDTO, denoise); await addT2IAdapterToGraph(g, layer, imageDTO, denoise);
} }
} }
return layersWithValidControlAdapters; return validControlLayers;
}; };
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
@ -59,12 +54,14 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
const addControlNetToGraph = ( const addControlNetToGraph = (
g: Graph, g: Graph,
layer: CanvasLayerStateWithValidControlNet, layer: CanvasControlLayerState,
imageDTO: ImageDTO, imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'> denoise: Invocation<'denoise_latents'>
) => { ) => {
const { id, controlAdapter } = layer; const { id, controlAdapter } = layer;
assert(controlAdapter.type === 'controlnet');
const { beginEndStepPct, model, weight, controlMode } = controlAdapter; const { beginEndStepPct, model, weight, controlMode } = controlAdapter;
assert(model !== null);
const { image_name } = imageDTO; const { image_name } = imageDTO;
const controlNetCollect = addControlNetCollectorSafe(g, denoise); const controlNetCollect = addControlNetCollectorSafe(g, denoise);
@ -103,12 +100,14 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
const addT2IAdapterToGraph = ( const addT2IAdapterToGraph = (
g: Graph, g: Graph,
layer: CanvasLayerStateWithValidT2IAdapter, layer: CanvasControlLayerState,
imageDTO: ImageDTO, imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'> denoise: Invocation<'denoise_latents'>
) => { ) => {
const { id, controlAdapter } = layer; const { id, controlAdapter } = layer;
assert(controlAdapter.type === 't2i_adapter');
const { beginEndStepPct, model, weight } = controlAdapter; const { beginEndStepPct, model, weight } = controlAdapter;
assert(model !== null);
const { image_name } = imageDTO; const { image_name } = imageDTO;
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
@ -127,25 +126,6 @@ const addT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
}; };
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: FilterConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.image_name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.image_name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => { const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model // Must be have a model
const hasModel = Boolean(controlAdapter.model); const hasModel = Boolean(controlAdapter.model);
@ -153,22 +133,3 @@ const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConf
const modelMatchesBase = controlAdapter.model?.base === base; const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase; return hasModel && modelMatchesBase;
}; };
const doesLayerHaveValidControlAdapter = (
layer: CanvasLayerState,
base: BaseModelType
): layer is CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter => {
if (!layer.controlAdapter) {
// Must have a control adapter
return false;
}
if (!layer.controlAdapter.model) {
// Control adapter must have a model selected
return false;
}
if (layer.controlAdapter.model.base !== base) {
// Selected model must match current base model
return false;
}
return true;
};

View File

@ -22,7 +22,7 @@ export const addInpaint = async (
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect); const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const maskImage = await manager.inpaintMask.renderer.rasterize(bbox.rect); const maskImage = await manager.inpaintMaskAdapter.renderer.rasterize(bbox.rect);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {
// Scale before processing requires some resizing // Scale before processing requires some resizing

View File

@ -1,6 +1,6 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types'; import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
export const isValidLayerWithoutControlAdapter = (layer: CanvasLayerState) => { export const isValidLayerWithoutControlAdapter = (layer: CanvasRasterLayerState) => {
return ( return (
layer.isEnabled && layer.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers // Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers

View File

@ -23,7 +23,7 @@ export const addOutpaint = async (
denoise.denoising_start = denoising_start; denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect); const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const maskImage = await manager.inpaintMask.renderer.rasterize(bbox.rect); const maskImage = await manager.inpaintMaskAdapter.renderer.rasterize(bbox.rect);
const infill = getInfill(g, compositing); const infill = getInfill(g, compositing);
if (!isEqual(scaledSize, originalSize)) { if (!isEqual(scaledSize, originalSize)) {

View File

@ -43,7 +43,7 @@ export const addRegions = async (
const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const validRegions = regions.filter((rg) => isValidRegion(rg, base));
for (const region of validRegions) { for (const region of validRegions) {
const adapter = manager.regions.get(region.id); const adapter = manager.regionalGuidanceAdapters.get(region.id);
assert(adapter, 'Adapter not found'); assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize(bbox); const imageDTO = await adapter.renderer.rasterize(bbox);

View File

@ -215,7 +215,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const _addedCAs = await addControlAdapters( const _addedCAs = await addControlAdapters(
manager, manager,
state.canvasV2.layers.entities, state.canvasV2.rasterLayers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, denoise,

View File

@ -219,7 +219,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const _addedCAs = await addControlAdapters( const _addedCAs = await addControlAdapters(
manager, manager,
state.canvasV2.layers.entities, state.canvasV2.rasterLayers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, denoise,