feat(ui): split control layers from raster layers for UI and internal state, same rendering as raster layers

This commit is contained in:
psychedelicious 2024-08-15 10:35:07 +10:00
parent 96abf687f6
commit 1435557d1d
59 changed files with 866 additions and 671 deletions

View File

@ -1674,7 +1674,9 @@
"opacity": "Opacity",
"regionalGuidance_withCount": "Regional Guidance ({{count}})",
"controlAdapters_withCount": "Control Adapters ({{count}})",
"layers_withCount": "Raster Layers ({{count}})",
"controlLayer": "Control Layer",
"controlLayers_withCount": "Control Layers ({{count}})",
"rasterLayers_withCount": "Raster Layers ({{count}})",
"ipAdapters_withCount": "IP Adapters ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",

View File

@ -2,11 +2,11 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
$lastProgressEvent,
layerAdded,
rasterLayerAdded,
sessionStagingAreaImageAccepted,
sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@ -62,12 +62,12 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasLayerState> = {
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: x + offsetX, y: y + offsetY },
objects: [imageObject],
};
api.dispatch(layerAdded({ overrides }));
api.dispatch(rasterLayerAdded({ overrides }));
api.dispatch(sessionStagingAreaReset());
},
});

View File

@ -1,5 +1,5 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { ipaAllDeleted, layerAllDeleted } from 'features/controlLayers/store/canvasV2Slice';
import { ipaAllDeleted, rasterLayerAllDeleted } from 'features/controlLayers/store/canvasV2Slice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -22,7 +22,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const imageUsage = getImageUsage(nodes.present, canvasV2, image_name);
if (imageUsage.isLayerImage && !wereLayersReset) {
dispatch(layerAllDeleted());
dispatch(rasterLayerAllDeleted());
wereLayersReset = true;
}

View File

@ -55,7 +55,7 @@ const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO
};
const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.layers.entities.forEach(({ id, objects }) => {
state.canvasV2.rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
@ -64,7 +64,7 @@ const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'layer' } }));
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
}
});
};

View File

@ -4,10 +4,10 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import {
ipaImageChanged,
layerAdded,
rasterLayerAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
@ -108,11 +108,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
) {
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = getState().canvasV2.bbox.rect;
const overrides: Partial<CanvasLayerState> = {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(layerAdded({ overrides, isSelected: true }));
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return;
}

View File

@ -2,7 +2,6 @@ import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityState } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
@ -18,14 +17,13 @@ import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<CanvasEntityState['type'], string> = {
control_adapter: 'controlLayers.globalControlAdapter',
ip_adapter: 'controlLayers.globalIPAdapter',
regional_guidance: 'controlLayers.regionalGuidance',
layer: 'controlLayers.raster',
const LAYER_TYPE_TO_TKEY = {
ip_adapter: 'controlLayers.ipAdapter',
inpaint_mask: 'controlLayers.inpaintMask',
initial_image: 'controlLayers.initialImage',
};
regional_guidance: 'controlLayers.regionalGuidance',
raster_layer: 'controlLayers.raster',
control_layer: 'controlLayers.globalControlAdapter',
} as const;
const createSelector = (templates: Templates) =>
createMemoizedSelector(
@ -125,41 +123,35 @@ const createSelector = (templates: Templates) =>
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
// canvasV2.controlAdapters.entities
// .filter((ca) => ca.isEnabled)
// .forEach((ca, i) => {
// const layerLiteral = i18n.t('controlLayers.layers_one');
// const layerNumber = i + 1;
// const layerType = i18n.t(LAYER_TYPE_TO_TKEY[ca.type]);
// const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
// const problems: string[] = [];
// // Must have model
// if (!ca.model) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
// }
// // Model base must match
// if (ca.model?.base !== model?.base) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
// }
// // Must have a control image OR, if it has a processor, it must have a processed image
// if (!ca.imageObject) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
// } else if (ca.processorConfig && !ca.processedImageObject) {
// problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
// }
// // T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
// if (ca.adapterType === 't2i_adapter') {
// const multiple = model?.base === 'sdxl' ? 32 : 64;
// if (bbox.rect.width % multiple !== 0 || bbox.rect.height % multiple !== 0) {
// problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
// }
// }
canvasV2.controlLayers.entities
.filter((controlLayer) => controlLayer.isEnabled)
.forEach((controlLayer, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have model
if (!controlLayer.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (controlLayer.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
if (controlLayer.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (bbox.rect.width % multiple !== 0 || bbox.rect.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
}
}
// if (problems.length) {
// const content = upperFirst(problems.join(', '));
// reasons.push({ prefix, content });
// }
// });
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
canvasV2.ipAdapters.entities
.filter((ipa) => ipa.isEnabled)
@ -226,8 +218,9 @@ const createSelector = (templates: Templates) =>
}
});
canvasV2.layers.entities
canvasV2.rasterLayers.entities
.filter((l) => l.isEnabled)
.filter((l) => l.type === 'raster_layer')
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;

View File

@ -1,7 +1,7 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAddCALayer, useAddIPALayer } from 'features/controlLayers/hooks/addLayerHooks';
import { layerAdded, rgAdded } from 'features/controlLayers/store/canvasV2Slice';
import { rasterLayerAdded, rgAdded } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@ -15,7 +15,7 @@ export const AddLayerButton = memo(() => {
dispatch(rgAdded());
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(layerAdded({ isSelected: true }));
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
return (

View File

@ -1,8 +1,9 @@
import { Flex } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { ControlLayerEntityList } from 'features/controlLayers/components/ControlLayer/ControlLayerEntityList';
import { InpaintMask } from 'features/controlLayers/components/InpaintMask/InpaintMask';
import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
import { LayerEntityList } from 'features/controlLayers/components/Layer/LayerEntityList';
import { RasterLayerEntityList } from 'features/controlLayers/components/RasterLayer/RasterLayerEntityList';
import { RegionalGuidanceEntityList } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceEntityList';
import { memo } from 'react';
@ -13,7 +14,8 @@ export const CanvasEntityList = memo(() => {
<InpaintMask />
<RegionalGuidanceEntityList />
<IPAdapterList />
<LayerEntityList />
<ControlLayerEntityList />
<RasterLayerEntityList />
</Flex>
</ScrollableContent>
);

View File

@ -0,0 +1,39 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { ControlLayerActionsMenu } from 'features/controlLayers/components/ControlLayer/ControlLayerActionsMenu';
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
type Props = {
id: string;
};
export const ControlLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'control_layer' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<ControlLayerActionsMenu />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
<CanvasEntitySettingsWrapper>
<ControlLayerControlAdapter />
</CanvasEntitySettingsWrapper>
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);
});
ControlLayer.displayName = 'ControlLayer';

View File

@ -0,0 +1,17 @@
import { Menu, MenuList } from '@invoke-ai/ui-library';
import { CanvasEntityActionMenuItems } from 'features/controlLayers/components/common/CanvasEntityActionMenuItems';
import { CanvasEntityMenuButton } from 'features/controlLayers/components/common/CanvasEntityMenuButton';
import { memo } from 'react';
export const ControlLayerActionsMenu = memo(() => {
return (
<Menu>
<CanvasEntityMenuButton />
<MenuList>
<CanvasEntityActionMenuItems />
</MenuList>
</Menu>
);
});
ControlLayerActionsMenu.displayName = 'ControlLayerActionsMenu';

View File

@ -5,50 +5,48 @@ import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlAdapterControlModeSelect } from 'features/controlLayers/components/ControlAdapter/ControlAdapterControlModeSelect';
import { ControlAdapterModel } from 'features/controlLayers/components/ControlAdapter/ControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useControlLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import {
layerControlAdapterBeginEndStepPctChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterModelChanged,
layerControlAdapterWeightChanged,
controlLayerBeginEndStepPctChanged,
controlLayerControlModeChanged,
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type { ControlModeV2, ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/store/types';
import type { ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
controlAdapter: ControlNetConfig | T2IAdapterConfig;
};
export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
export const ControlLayerControlAdapter = memo(() => {
const dispatch = useAppDispatch();
const { id } = useEntityIdentifierContext();
const entityIdentifier = useEntityIdentifierContext();
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(layerControlAdapterBeginEndStepPctChanged({ id, beginEndStepPct }));
dispatch(controlLayerBeginEndStepPctChanged({ id: entityIdentifier.id, beginEndStepPct }));
},
[dispatch, id]
[dispatch, entityIdentifier.id]
);
const onChangeControlMode = useCallback(
(controlMode: ControlModeV2) => {
dispatch(layerControlAdapterControlModeChanged({ id, controlMode }));
dispatch(controlLayerControlModeChanged({ id: entityIdentifier.id, controlMode }));
},
[dispatch, id]
[dispatch, entityIdentifier.id]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(layerControlAdapterWeightChanged({ id, weight }));
dispatch(controlLayerWeightChanged({ id: entityIdentifier.id, weight }));
},
[dispatch, id]
[dispatch, entityIdentifier.id]
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(layerControlAdapterModelChanged({ id, modelConfig }));
dispatch(controlLayerModelChanged({ id: entityIdentifier.id, modelConfig }));
},
[dispatch, id]
[dispatch, entityIdentifier.id]
);
return (
@ -63,4 +61,4 @@ export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
);
});
LayerControlAdapter.displayName = 'LayerControlAdapter';
ControlLayerControlAdapter.displayName = 'ControlLayerControlAdapter';

View File

@ -0,0 +1,38 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { ControlLayer } from 'features/controlLayers/components/ControlLayer/ControlLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selectEntityIds = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
return canvasV2.controlLayers.entities.map(mapId).reverse();
});
export const ControlLayerEntityList = memo(() => {
const { t } = useTranslation();
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'control_layer'));
const layerIds = useAppSelector(selectEntityIds);
if (layerIds.length === 0) {
return null;
}
if (layerIds.length > 0) {
return (
<>
<CanvasEntityGroupTitle
title={t('controlLayers.controlLayers_withCount', { count: layerIds.length })}
isSelected={isSelected}
/>
{layerIds.map((id) => (
<ControlLayer key={id} id={id} />
))}
</>
);
}
});
ControlLayerEntityList.displayName = 'ControlLayerEntityList';

View File

@ -10,8 +10,10 @@ import {
PopoverContent,
PopoverTrigger,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { MaskOpacity } from 'features/controlLayers/components/MaskOpacity';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import {
clipToBboxChanged,
invertScrollChanged,
@ -25,6 +27,7 @@ import { RiSettings4Fill } from 'react-icons/ri';
const ControlLayersSettingsPopover = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const canvasManager = useStore($canvasManager);
const clipToBbox = useAppSelector((s) => s.canvasV2.settings.clipToBbox);
const invertScroll = useAppSelector((s) => s.canvasV2.tool.invertScroll);
const onChangeInvertScroll = useCallback(
@ -38,6 +41,21 @@ const ControlLayersSettingsPopover = () => {
const invalidateRasterizationCaches = useCallback(() => {
dispatch(rasterizationCachesInvalidated());
}, [dispatch]);
const calculateBboxes = useCallback(() => {
if (!canvasManager) {
return;
}
for (const adapter of canvasManager.rasterLayerAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
for (const adapter of canvasManager.controlLayerAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
for (const adapter of canvasManager.regionalGuidanceAdapters.values()) {
adapter.transformer.requestRectCalculation();
}
canvasManager.inpaintMaskAdapter.transformer.requestRectCalculation();
}, [canvasManager]);
return (
<Popover isLazy>
<PopoverTrigger>
@ -58,6 +76,9 @@ const ControlLayersSettingsPopover = () => {
<Button onClick={invalidateRasterizationCaches} size="sm">
Invalidate Rasterization Caches
</Button>
<Button onClick={calculateBboxes} size="sm">
Calculate Bboxes
</Button>
</Flex>
</PopoverBody>
</PopoverContent>

View File

@ -1,5 +1,5 @@
/* eslint-disable i18next/no-literal-string */
import { Button, Flex, Switch } from '@invoke-ai/ui-library';
import { Flex, Switch } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { BrushWidth } from 'features/controlLayers/components/BrushWidth';
@ -12,36 +12,15 @@ import { ResetCanvasButton } from 'features/controlLayers/components/ResetCanvas
import { ToolChooser } from 'features/controlLayers/components/ToolChooser';
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { nanoid } from 'features/controlLayers/konva/util';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
const filter = () => {
const entity = $canvasManager.get()?.stateApi.getSelectedEntity();
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter({
type: 'canny_image_processor',
id: nanoid(),
low_threshold: 50,
high_threshold: 50,
});
};
export const ControlLayersToolbar = memo(() => {
const tool = useAppSelector((s) => s.canvasV2.tool.selected);
const canvasManager = useStore($canvasManager);
const bbox = useCallback(() => {
if (!canvasManager) {
return;
}
for (const l of canvasManager.layers.values()) {
l.transformer.requestRectCalculation();
}
}, [canvasManager]);
const onChangeDebugging = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
if (!canvasManager) {
@ -61,7 +40,6 @@ export const ControlLayersToolbar = memo(() => {
<Flex gap={2} marginInlineEnd="auto" alignItems="center">
<ToggleProgressButton />
<ToolChooser />
<Button onClick={filter}>Filter</Button>
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center" alignItems="center">
@ -70,7 +48,6 @@ export const ControlLayersToolbar = memo(() => {
</Flex>
<CanvasScale />
<CanvasResetViewButton />
<Button onClick={bbox}>bbox</Button>
<Switch onChange={onChangeDebugging}>debug</Switch>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center">

View File

@ -13,7 +13,7 @@ export const DeleteAllLayersButton = memo(() => {
s.canvasV2.regions.entities.length +
// s.canvasV2.controlAdapters.entities.length +
s.canvasV2.ipAdapters.entities.length +
s.canvasV2.layers.entities.length
s.canvasV2.rasterLayers.entities.length
);
});
const onClick = useCallback(() => {

View File

@ -18,7 +18,7 @@ export const Filter = memo(() => {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return;
}
entity.adapter.filter.previewFilter();
@ -33,7 +33,7 @@ export const Filter = memo(() => {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return;
}
entity.adapter.filter.applyFilter();
@ -48,7 +48,7 @@ export const Filter = memo(() => {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
if (!entity || (entity.type !== 'raster_layer' && entity.type !== 'control_layer')) {
return;
}
entity.adapter.filter.cancelFilter();

View File

@ -1,17 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useFilter } from 'features/controlLayers/components/Filters/Filter';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
const FilterWrapper = (props: PropsWithChildren) => {
const isPreviewDisabled = useAppSelector((s) => s.canvasV2.selectedEntityIdentifier?.type !== 'layer');
const filter = useFilter();
return (
<Flex flexDir="column" gap={3} w="full" h="full">
{props.children}
</Flex>
);
};
export default memo(FilterWrapper);

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -15,18 +15,17 @@ type Props = {
export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
{isOpen && <IPAdapterSettings />}
<IPAdapterSettings />
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);

View File

@ -1,7 +1,7 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@ -73,7 +73,7 @@ export const IPAdapterSettings = memo(() => {
const postUploadAction = useMemo<IPALayerImagePostUploadAction>(() => ({ type: 'SET_IPA_IMAGE', id }), [id]);
return (
<CanvasEntitySettings>
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={4} position="relative" w="full">
<Flex gap={3} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
@ -102,7 +102,7 @@ export const IPAdapterSettings = memo(() => {
</Flex>
</Flex>
</Flex>
</CanvasEntitySettings>
</CanvasEntitySettingsWrapper>
);
});

View File

@ -1,32 +1,38 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { InpaintMaskActionsMenu } from 'features/controlLayers/components/InpaintMask/InpaintMaskActionsMenu';
import { InpaintMaskSettings } from 'features/controlLayers/components/InpaintMask/InpaintMaskSettings';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { InpaintMaskMaskFillColorPicker } from './InpaintMaskMaskFillColorPicker';
export const InpaintMask = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id: 'inpaint_mask', type: 'inpaint_mask' }), []);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: false });
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'inpaint_mask'));
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<InpaintMaskMaskFillColorPicker />
<InpaintMaskActionsMenu />
</CanvasEntityHeader>
{isOpen && <InpaintMaskSettings />}
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
<>
<CanvasEntityGroupTitle title={t('controlLayers.inpaintMask')} isSelected={isSelected} />
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<InpaintMaskMaskFillColorPicker />
<InpaintMaskActionsMenu />
</CanvasEntityHeader>
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
</>
);
});

View File

@ -1,8 +0,0 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { memo } from 'react';
export const InpaintMaskSettings = memo(() => {
return <CanvasEntitySettings>PLACEHOLDER</CanvasEntitySettings>;
});
InpaintMaskSettings.displayName = 'InpaintMaskSettings';

View File

@ -1,22 +0,0 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { LayerControlAdapter } from 'features/controlLayers/components/Layer/LayerControlAdapter';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { memo } from 'react';
export const LayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const controlAdapter = useLayerControlAdapter(entityIdentifier);
if (!controlAdapter) {
return null;
}
return (
<CanvasEntitySettings>
<LayerControlAdapter controlAdapter={controlAdapter} />
</CanvasEntitySettings>
);
});
LayerSettings.displayName = 'LayerSettings';

View File

@ -4,8 +4,7 @@ import { CanvasEntityDeleteButton } from 'features/controlLayers/components/comm
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { LayerActionsMenu } from 'features/controlLayers/components/Layer/LayerActionsMenu';
import { LayerSettings } from 'features/controlLayers/components/Layer/LayerSettings';
import { RasterLayerActionsMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerActionsMenu';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
@ -14,8 +13,8 @@ type Props = {
id: string;
};
export const Layer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'layer' }), [id]);
export const RasterLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'raster_layer' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
@ -24,13 +23,12 @@ export const Layer = memo(({ id }: Props) => {
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
<LayerActionsMenu />
<RasterLayerActionsMenu />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
<LayerSettings />
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);
});
Layer.displayName = 'Layer';
RasterLayer.displayName = 'RasterLayer';

View File

@ -3,7 +3,7 @@ import { CanvasEntityActionMenuItems } from 'features/controlLayers/components/c
import { CanvasEntityMenuButton } from 'features/controlLayers/components/common/CanvasEntityMenuButton';
import { memo } from 'react';
export const LayerActionsMenu = memo(() => {
export const RasterLayerActionsMenu = memo(() => {
return (
<Menu>
<CanvasEntityMenuButton />
@ -14,4 +14,4 @@ export const LayerActionsMenu = memo(() => {
);
});
LayerActionsMenu.displayName = 'LayerActionsMenu';
RasterLayerActionsMenu.displayName = 'RasterLayerActionsMenu';

View File

@ -1,19 +1,19 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupTitle } from 'features/controlLayers/components/common/CanvasEntityGroupTitle';
import { Layer } from 'features/controlLayers/components/Layer/Layer';
import { RasterLayer } from 'features/controlLayers/components/RasterLayer/RasterLayer';
import { mapId } from 'features/controlLayers/konva/util';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selectEntityIds = createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
return canvasV2.layers.entities.map(mapId).reverse();
return canvasV2.rasterLayers.entities.map(mapId).reverse();
});
export const LayerEntityList = memo(() => {
export const RasterLayerEntityList = memo(() => {
const { t } = useTranslation();
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'layer'));
const isSelected = useAppSelector((s) => Boolean(s.canvasV2.selectedEntityIdentifier?.type === 'raster_layer'));
const layerIds = useAppSelector(selectEntityIds);
if (layerIds.length === 0) {
@ -24,15 +24,15 @@ export const LayerEntityList = memo(() => {
return (
<>
<CanvasEntityGroupTitle
title={t('controlLayers.layers_withCount', { count: layerIds.length })}
title={t('controlLayers.rasterLayers_withCount', { count: layerIds.length })}
isSelected={isSelected}
/>
{layerIds.map((id) => (
<Layer key={id} id={id} />
<RasterLayer key={id} id={id} />
))}
</>
);
}
});
LayerEntityList.displayName = 'LayerEntityList';
RasterLayerEntityList.displayName = 'RasterLayerEntityList';

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -20,11 +20,10 @@ type Props = {
export const RegionalGuidance = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'regional_guidance' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
@ -34,7 +33,7 @@ export const RegionalGuidance = memo(({ id }: Props) => {
<RegionalGuidanceActionsMenu />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
{isOpen && <RegionalGuidanceSettings />}
<RegionalGuidanceSettings />
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);

View File

@ -1,6 +1,6 @@
import { useAppSelector } from 'app/store/storeHooks';
import { AddPromptButtons } from 'features/controlLayers/components/AddPromptButtons';
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo } from 'react';
@ -16,12 +16,12 @@ export const RegionalGuidanceSettings = memo(() => {
const hasIPAdapters = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).ipAdapters.length > 0);
return (
<CanvasEntitySettings>
<CanvasEntitySettingsWrapper>
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons id={id} />}
{hasPositivePrompt && <RegionalGuidancePositivePrompt id={id} />}
{hasNegativePrompt && <RegionalGuidanceNegativePrompt id={id} />}
{hasIPAdapters && <RegionalGuidanceIPAdapters id={id} />}
</CanvasEntitySettings>
</CanvasEntitySettingsWrapper>
);
});

View File

@ -3,16 +3,18 @@ import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerUseAsControl } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { useDefaultControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import {
$filteringEntity,
controlLayerConvertedToRasterLayer,
entityArrangedBackwardOne,
entityArrangedForwardOne,
entityArrangedToBack,
entityArrangedToFront,
entityDeleted,
entityReset,
rasterLayerConvertedToControlLayer,
selectCanvasV2Slice,
} from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityIdentifier, CanvasV2State } from 'features/controlLayers/store/types';
@ -28,17 +30,21 @@ import {
PiQuestionMarkBold,
PiStarHalfBold,
PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi';
const getIndexAndCount = (
canvasV2: CanvasV2State,
{ id, type }: CanvasEntityIdentifier
): { index: number; count: number } => {
if (type === 'layer') {
if (type === 'raster_layer') {
return {
index: canvasV2.layers.entities.findIndex((entity) => entity.id === id),
count: canvasV2.layers.entities.length,
index: canvasV2.rasterLayers.entities.findIndex((entity) => entity.id === id),
count: canvasV2.rasterLayers.entities.length,
};
} else if (type === 'control_layer') {
return {
index: canvasV2.controlLayers.entities.findIndex((entity) => entity.id === id),
count: canvasV2.controlLayers.entities.length,
};
} else if (type === 'regional_guidance') {
return {
@ -58,7 +64,6 @@ export const CanvasEntityActionMenuItems = memo(() => {
const canvasManager = useStore($canvasManager);
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const useAsControl = useLayerUseAsControl(entityIdentifier);
const selectValidActions = useMemo(
() =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
@ -76,16 +81,39 @@ export const CanvasEntityActionMenuItems = memo(() => {
const validActions = useAppSelector(selectValidActions);
const isArrangeable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
() =>
entityIdentifier.type === 'raster_layer' ||
entityIdentifier.type === 'control_layer' ||
entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isDeleteable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
() =>
entityIdentifier.type === 'raster_layer' ||
entityIdentifier.type === 'control_layer' ||
entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isFilterable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const isUseAsControlable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const isFilterable = useMemo(
() => entityIdentifier.type === 'raster_layer' || entityIdentifier.type === 'control_layer',
[entityIdentifier.type]
);
const isRasterLayer = useMemo(() => entityIdentifier.type === 'raster_layer', [entityIdentifier.type]);
const isControlLayer = useMemo(() => entityIdentifier.type === 'control_layer', [entityIdentifier.type]);
const defaultControlAdapter = useDefaultControlAdapter();
const convertRasterLayerToControlLayer = useCallback(() => {
dispatch(rasterLayerConvertedToControlLayer({ id: entityIdentifier.id, controlAdapter: defaultControlAdapter }));
}, [dispatch, defaultControlAdapter, entityIdentifier.id]);
const convertControlLayerToRasterLayer = useCallback(() => {
dispatch(controlLayerConvertedToRasterLayer({ id: entityIdentifier.id }));
}, [dispatch, entityIdentifier.id]);
const deleteEntity = useCallback(() => {
dispatch(entityDeleted({ entityIdentifier }));
@ -142,9 +170,14 @@ export const CanvasEntityActionMenuItems = memo(() => {
{t('common.filter')}
</MenuItem>
)}
{isUseAsControlable && (
<MenuItem onClick={useAsControl.toggle} icon={useAsControl.hasControlAdapter ? <PiXBold /> : <PiCheckBold />}>
{useAsControl.hasControlAdapter ? t('common.removeControl') : t('common.useAsControl')}
{isRasterLayer && (
<MenuItem onClick={convertRasterLayerToControlLayer} icon={<PiCheckBold />}>
{t('common.convertToControlLayer')}
</MenuItem>
)}
{isControlLayer && (
<MenuItem onClick={convertControlLayerToRasterLayer} icon={<PiCheckBold />}>
{t('common.convertToRasterLayer')}
</MenuItem>
)}
<MenuDivider />

View File

@ -8,7 +8,7 @@ type Props = {
export const CanvasEntityGroupTitle = memo(({ title, isSelected }: Props) => {
return (
<Text color={isSelected ? 'base.100' : 'base.300'} userSelect="none">
<Text color={isSelected ? 'base.200' : 'base.500'} fontWeight="semibold" userSelect="none">
{title}
</Text>
);

View File

@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
export const CanvasEntitySettings = memo(({ children }: PropsWithChildren) => {
export const CanvasEntitySettingsWrapper = memo(({ children }: PropsWithChildren) => {
return (
<Flex flexDir="column" gap={3} px={3} pb={3}>
{children}
@ -10,4 +10,4 @@ export const CanvasEntitySettings = memo(({ children }: PropsWithChildren) => {
);
});
CanvasEntitySettings.displayName = 'CanvasEntitySettings';
CanvasEntitySettingsWrapper.displayName = 'CanvasEntitySettingsWrapper';

View File

@ -1,10 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
entityReset,
selectCanvasV2Slice,
} from 'features/controlLayers/store/canvasV2Slice';
import { entityReset, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
@ -17,7 +14,6 @@ export function useCanvasResetLayerHotkey() {
useAssertSingleton(useCanvasResetLayerHotkey.name);
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isStaging = useAppSelector((s) => s.canvasV2.session.isStaging);
const resetSelectedLayer = useCallback(() => {
if (selectedEntityIdentifier === null) {
@ -27,16 +23,9 @@ export function useCanvasResetLayerHotkey() {
}, [dispatch, selectedEntityIdentifier]);
const isResetEnabled = useMemo(
() =>
(!isStaging && selectedEntityIdentifier?.type === 'layer') ||
selectedEntityIdentifier?.type === 'regional_guidance' ||
selectedEntityIdentifier?.type === 'inpaint_mask',
[isStaging, selectedEntityIdentifier?.type]
() => selectedEntityIdentifier?.type === 'inpaint_mask',
[selectedEntityIdentifier?.type]
);
useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [
isResetEnabled,
isStaging,
resetSelectedLayer,
]);
useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [isResetEnabled, resetSelectedLayer]);
}

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectCanvasV2Slice, selectEntity } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { type CanvasEntityIdentifier,isDrawableEntity } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
export const useEntityObjectCount = (entityIdentifier: CanvasEntityIdentifier) => {
@ -11,11 +11,7 @@ export const useEntityObjectCount = (entityIdentifier: CanvasEntityIdentifier) =
const entity = selectEntity(canvasV2, entityIdentifier);
if (!entity) {
return 0;
} else if (entity.type === 'layer') {
return entity.objects.length;
} else if (entity.type === 'inpaint_mask') {
return entity.objects.length;
} else if (entity.type === 'regional_guidance') {
} else if (isDrawableEntity(entity)) {
return entity.objects.length;
} else {
return 0;

View File

@ -13,10 +13,10 @@ export const useEntityTitle = (entityIdentifier: CanvasEntityIdentifier) => {
const parts: string[] = [];
if (entityIdentifier.type === 'inpaint_mask') {
parts.push(t('controlLayers.inpaintMask'));
} else if (entityIdentifier.type === 'control_adapter') {
parts.push(t('controlLayers.globalControlAdapter'));
} else if (entityIdentifier.type === 'layer') {
parts.push(t('controlLayers.layer'));
} else if (entityIdentifier.type === 'control_layer') {
parts.push(t('controlLayers.controlLayer'));
} else if (entityIdentifier.type === 'raster_layer') {
parts.push(t('controlLayers.rasterLayer'));
} else if (entityIdentifier.type === 'ip_adapter') {
parts.push(t('controlLayers.ipAdapter'));
} else if (entityIdentifier.type === 'regional_guidance') {

View File

@ -1,23 +1,19 @@
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { layerUsedAsControlChanged, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectLayer } from 'features/controlLayers/store/layersReducers';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectControlLayerOrThrow } from 'features/controlLayers/store/controlLayersReducers';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { initialControlNetV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react';
import { useMemo } from 'react';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
export const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const selectControlAdapter = useMemo(
() =>
createMemoizedAppSelector(selectCanvasV2Slice, (canvasV2) => {
const layer = selectLayer(canvasV2, entityIdentifier.id);
if (!layer) {
return null;
}
const layer = selectControlLayerOrThrow(canvasV2, entityIdentifier.id);
return layer.controlAdapter;
}),
[entityIdentifier]
@ -26,32 +22,23 @@ export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier)
return controlAdapter;
};
export const useLayerUseAsControl = (entityIdentifier: CanvasEntityIdentifier) => {
const dispatch = useAppDispatch();
export const useDefaultControlAdapter = () => {
const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
const controlAdapter = useLayerControlAdapter(entityIdentifier);
const model: ControlNetModelConfig | T2IAdapterModelConfig | null = useMemo(() => {
// prefer to use a model that matches the base model
const defaultControlAdapter = useMemo(() => {
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
return compatibleModels[0] ?? modelConfigs[0] ?? null;
}, [baseModel, modelConfigs]);
const toggle = useCallback(() => {
if (controlAdapter) {
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: null }));
return;
}
const newControlAdapter = deepClone(model?.type === 't2i_adapter' ? initialT2IAdapterV2 : initialControlNetV2);
const model = compatibleModels[0] ?? modelConfigs[0] ?? null;
const controlAdapter =
model?.type === 't2i_adapter' ? deepClone(initialT2IAdapterV2) : deepClone(initialControlNetV2);
if (model) {
newControlAdapter.model = zModelIdentifierField.parse(model);
controlAdapter.model = zModelIdentifierField.parse(model);
}
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: newControlAdapter }));
}, [controlAdapter, dispatch, entityIdentifier.id, model]);
return controlAdapter;
}, [baseModel, modelConfigs]);
return { hasControlAdapter: Boolean(controlAdapter), toggle };
return defaultControlAdapter;
};

View File

@ -4,7 +4,12 @@ import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer';
import type { CanvasEntityIdentifier, CanvasLayerState, CanvasV2State } from 'features/controlLayers/store/types';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasV2State,
} from 'features/controlLayers/store/types';
import Konva from 'konva';
import { get } from 'lodash-es';
import type { Logger } from 'roarr';
@ -17,7 +22,7 @@ export class CanvasLayerAdapter {
manager: CanvasManager;
log: Logger;
state: CanvasLayerState;
state: CanvasRasterLayerState | CanvasControlLayerState;
konva: {
layer: Konva.Layer;
@ -110,7 +115,7 @@ export class CanvasLayerAdapter {
this.konva.layer.visible(isEnabled);
};
updateObjects = async (arg?: { objects: CanvasLayerState['objects'] }) => {
updateObjects = async (arg?: { objects: CanvasRasterLayerState['objects'] }) => {
this.log.trace('Updating objects');
const objects = get(arg, 'objects', this.state.objects);

View File

@ -29,7 +29,6 @@ import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { CanvasBackground } from './CanvasBackground';
import type { CanvasControlAdapter } from './CanvasControlAdapter';
import { CanvasLayerAdapter } from './CanvasLayerAdapter';
import { CanvasMaskAdapter } from './CanvasMaskAdapter';
import { CanvasPreview } from './CanvasPreview';
@ -46,10 +45,10 @@ export class CanvasManager {
path: string[];
stage: Konva.Stage;
container: HTMLDivElement;
controlAdapters: Map<string, CanvasControlAdapter>;
layers: Map<string, CanvasLayerAdapter>;
regions: Map<string, CanvasMaskAdapter>;
inpaintMask: CanvasMaskAdapter;
rasterLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
controlLayerAdapters: Map<string, CanvasLayerAdapter> = new Map();
regionalGuidanceAdapters: Map<string, CanvasMaskAdapter> = new Map();
inpaintMaskAdapter: CanvasMaskAdapter;
stateApi: CanvasStateApi;
preview: CanvasPreview;
background: CanvasBackground;
@ -94,10 +93,6 @@ export class CanvasManager {
this.background = new CanvasBackground(this);
this.stage.add(this.background.konva.layer);
this.layers = new Map();
this.regions = new Map();
this.controlAdapters = new Map();
this._worker.onmessage = (event: MessageEvent<ExtentsResult | WorkerLogMessage>) => {
const { type, data } = event.data;
if (type === 'log') {
@ -128,8 +123,8 @@ export class CanvasManager {
this.stateApi.$currentFill.set(this.stateApi.getCurrentFill());
this.stateApi.$selectedEntity.set(this.stateApi.getSelectedEntity());
this.inpaintMask = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this);
this.stage.add(this.inpaintMask.konva.layer);
this.inpaintMaskAdapter = new CanvasMaskAdapter(this.stateApi.getInpaintMaskState(), this);
this.stage.add(this.inpaintMaskAdapter.konva.layer);
}
enableDebugging() {
@ -152,18 +147,24 @@ export class CanvasManager {
}
arrangeEntities() {
const { getLayersState, getRegionsState } = this.stateApi;
const layers = getLayersState().entities;
const regions = getRegionsState().entities;
let zIndex = 0;
this.background.konva.layer.zIndex(++zIndex);
for (const layer of layers) {
this.layers.get(layer.id)?.konva.layer.zIndex(++zIndex);
for (const layer of this.stateApi.getRasterLayersState().entities) {
this.rasterLayerAdapters.get(layer.id)?.konva.layer.zIndex(++zIndex);
}
for (const rg of regions) {
this.regions.get(rg.id)?.konva.layer.zIndex(++zIndex);
for (const layer of this.stateApi.getControlLayersState().entities) {
this.controlLayerAdapters.get(layer.id)?.konva.layer.zIndex(++zIndex);
}
this.inpaintMask.konva.layer.zIndex(++zIndex);
for (const rg of this.stateApi.getRegionsState().entities) {
this.regionalGuidanceAdapters.get(rg.id)?.konva.layer.zIndex(++zIndex);
}
this.inpaintMaskAdapter.konva.layer.zIndex(++zIndex);
this.preview.getLayer().zIndex(++zIndex);
}
@ -215,12 +216,14 @@ export class CanvasManager {
const { id, type } = transformingEntity;
if (type === 'layer') {
return this.layers.get(id) ?? null;
if (type === 'raster_layer') {
return this.rasterLayerAdapters.get(id) ?? null;
} else if (type === 'control_layer') {
return this.controlLayerAdapters.get(id) ?? null;
} else if (type === 'inpaint_mask') {
return this.inpaintMask;
return this.inpaintMaskAdapter;
} else if (type === 'regional_guidance') {
return this.regions.get(id) ?? null;
return this.regionalGuidanceAdapters.get(id) ?? null;
}
return null;
@ -268,21 +271,46 @@ export class CanvasManager {
return;
}
if (this._isFirstRender || state.layers.entities !== this._prevState.layers.entities) {
this.log.debug('Rendering layers');
if (this._isFirstRender || state.rasterLayers.entities !== this._prevState.rasterLayers.entities) {
this.log.debug('Rendering raster layers');
for (const canvasLayer of this.layers.values()) {
if (!state.layers.entities.find((l) => l.id === canvasLayer.id)) {
for (const canvasLayer of this.rasterLayerAdapters.values()) {
if (!state.rasterLayers.entities.find((l) => l.id === canvasLayer.id)) {
await canvasLayer.destroy();
this.layers.delete(canvasLayer.id);
this.rasterLayerAdapters.delete(canvasLayer.id);
}
}
for (const entityState of state.layers.entities) {
let adapter = this.layers.get(entityState.id);
for (const entityState of state.rasterLayers.entities) {
let adapter = this.rasterLayerAdapters.get(entityState.id);
if (!adapter) {
adapter = new CanvasLayerAdapter(entityState, this);
this.layers.set(adapter.id, adapter);
this.rasterLayerAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer);
}
await adapter.update({
state: entityState,
toolState: state.tool,
isSelected: state.selectedEntityIdentifier?.id === entityState.id,
});
}
}
if (this._isFirstRender || state.controlLayers.entities !== this._prevState.controlLayers.entities) {
this.log.debug('Rendering control layers');
for (const canvasLayer of this.controlLayerAdapters.values()) {
if (!state.controlLayers.entities.find((l) => l.id === canvasLayer.id)) {
await canvasLayer.destroy();
this.controlLayerAdapters.delete(canvasLayer.id);
}
}
for (const entityState of state.controlLayers.entities) {
let adapter = this.controlLayerAdapters.get(entityState.id);
if (!adapter) {
adapter = new CanvasLayerAdapter(entityState, this);
this.controlLayerAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer);
}
await adapter.update({
@ -303,18 +331,18 @@ export class CanvasManager {
this.log.debug('Rendering regions');
// Destroy the konva nodes for nonexistent entities
for (const canvasRegion of this.regions.values()) {
for (const canvasRegion of this.regionalGuidanceAdapters.values()) {
if (!state.regions.entities.find((rg) => rg.id === canvasRegion.id)) {
canvasRegion.destroy();
this.regions.delete(canvasRegion.id);
this.regionalGuidanceAdapters.delete(canvasRegion.id);
}
}
for (const entityState of state.regions.entities) {
let adapter = this.regions.get(entityState.id);
let adapter = this.regionalGuidanceAdapters.get(entityState.id);
if (!adapter) {
adapter = new CanvasMaskAdapter(entityState, this);
this.regions.set(adapter.id, adapter);
this.regionalGuidanceAdapters.set(adapter.id, adapter);
this.stage.add(adapter.konva.layer);
}
await adapter.update({
@ -333,7 +361,7 @@ export class CanvasManager {
state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id
) {
this.log.debug('Rendering inpaint mask');
await this.inpaintMask.update({
await this.inpaintMaskAdapter.update({
state: state.inpaintMask,
toolState: state.tool,
isSelected: state.selectedEntityIdentifier?.id === state.inpaintMask.id,
@ -354,11 +382,6 @@ export class CanvasManager {
await this.preview.bbox.render();
}
if (this._isFirstRender || state.layers !== this._prevState.layers || state.regions !== this._prevState.regions) {
// this.log.debug('Updating entity bboxes');
// debouncedUpdateBboxes(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions, onBboxChanged);
}
if (this._isFirstRender || state.session !== this._prevState.session) {
this.log.debug('Rendering staging area');
await this.preview.stagingArea.render();
@ -366,7 +389,7 @@ export class CanvasManager {
if (
this._isFirstRender ||
state.layers.entities !== this._prevState.layers.entities ||
state.rasterLayers.entities !== this._prevState.rasterLayers.entities ||
state.regions.entities !== this._prevState.regions.entities ||
state.inpaintMask !== this._prevState.inpaintMask ||
state.selectedEntityIdentifier?.id !== this._prevState.selectedEntityIdentifier?.id
@ -402,15 +425,15 @@ export class CanvasManager {
return () => {
this.log.debug('Cleaning up konva renderer');
this.inpaintMask.destroy();
for (const region of this.regions.values()) {
region.destroy();
this.inpaintMaskAdapter.destroy();
for (const adapter of this.regionalGuidanceAdapters.values()) {
adapter.destroy();
}
for (const layer of this.layers.values()) {
layer.destroy();
for (const adapter of this.rasterLayerAdapters.values()) {
adapter.destroy();
}
for (const controlAdapter of this.controlAdapters.values()) {
controlAdapter.destroy();
for (const adapter of this.controlLayerAdapters.values()) {
adapter.destroy();
}
this.background.destroy();
this.preview.destroy();
@ -507,7 +530,7 @@ export class CanvasManager {
}
getCompositeLayerStageClone = (): Konva.Stage => {
const layersState = this.stateApi.getLayersState();
const layersState = this.stateApi.getRasterLayersState();
const stageClone = this.stage.clone();
stageClone.scaleX(1);
@ -536,7 +559,7 @@ export class CanvasManager {
};
getCompositeRasterizedImageCache = (rect: Rect): ImageCache | null => {
const layerState = this.stateApi.getLayersState();
const layerState = this.stateApi.getRasterLayersState();
const imageCache = layerState.compositeRasterizationCache.find((cache) => isEqual(cache.rect, rect));
return imageCache ?? null;
};
@ -567,11 +590,11 @@ export class CanvasManager {
};
getInpaintMaskBlob = (rect?: Rect): Promise<Blob> => {
return this.inpaintMask.renderer.getBlob(rect);
return this.inpaintMaskAdapter.renderer.getBlob(rect);
};
getInpaintMaskImageData = (rect?: Rect): ImageData => {
return this.inpaintMask.renderer.getImageData(rect);
return this.inpaintMaskAdapter.renderer.getImageData(rect);
};
getGenerationMode(): GenerationMode {
@ -617,7 +640,7 @@ export class CanvasManager {
logDebugInfo() {
// eslint-disable-next-line no-console
console.log(this);
for (const layer of this.layers.values()) {
for (const layer of this.rasterLayerAdapters.values()) {
// eslint-disable-next-line no-console
console.log(layer);
}

View File

@ -26,14 +26,15 @@ import {
entityReset,
entitySelected,
eraserWidthChanged,
layerCompositeRasterized,
rasterLayerCompositeRasterized,
toolBufferChanged,
toolChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
EntityBrushLineAddedPayload,
@ -53,8 +54,14 @@ import { atom } from 'nanostores';
type EntityStateAndAdapter =
| {
id: string;
type: CanvasLayerState['type'];
state: CanvasLayerState;
type: CanvasRasterLayerState['type'];
state: CanvasRasterLayerState;
adapter: CanvasLayerAdapter;
}
| {
id: string;
type: CanvasControlLayerState['type'];
state: CanvasControlLayerState;
adapter: CanvasLayerAdapter;
}
| {
@ -63,12 +70,6 @@ type EntityStateAndAdapter =
state: CanvasInpaintMaskState;
adapter: CanvasMaskAdapter;
}
// | {
// id: string;
// type: CanvasControlAdapterState['type'];
// state: CanvasControlAdapterState;
// adapter: CanvasControlAdapter;
// }
| {
id: string;
type: CanvasRegionalGuidanceState['type'];
@ -117,7 +118,7 @@ export class CanvasStateApi {
};
compositeLayerRasterized = (arg: { imageName: string; rect: Rect }) => {
log.trace(arg, 'Composite layer rasterized');
this._store.dispatch(layerCompositeRasterized(arg));
this._store.dispatch(rasterLayerCompositeRasterized(arg));
};
setSelectedEntity = (arg: EntityIdentifierPayload) => {
log.trace({ arg }, 'Setting selected entity');
@ -157,8 +158,11 @@ export class CanvasStateApi {
getRegionsState = () => {
return this.getState().regions;
};
getLayersState = () => {
return this.getState().layers;
getRasterLayersState = () => {
return this.getState().rasterLayers;
};
getControlLayersState = () => {
return this.getState().controlLayers;
};
getInpaintMaskState = () => {
return this.getState().inpaintMask;
@ -185,15 +189,18 @@ export class CanvasStateApi {
let entityState: EntityStateAndAdapter['state'] | null = null;
let entityAdapter: EntityStateAndAdapter['adapter'] | null = null;
if (identifier.type === 'layer') {
entityState = state.layers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.layers.get(identifier.id) ?? null;
if (identifier.type === 'raster_layer') {
entityState = state.rasterLayers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.rasterLayerAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'control_layer') {
entityState = state.controlLayers.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.controlLayerAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'regional_guidance') {
entityState = state.regions.entities.find((i) => i.id === identifier.id) ?? null;
entityAdapter = this.manager.regions.get(identifier.id) ?? null;
entityAdapter = this.manager.regionalGuidanceAdapters.get(identifier.id) ?? null;
} else if (identifier.type === 'inpaint_mask') {
entityState = state.inpaintMask;
entityAdapter = this.manager.inpaintMask;
entityAdapter = this.manager.inpaintMaskAdapter;
}
if (entityState && entityAdapter) {

View File

@ -8,6 +8,7 @@ import {
BRUSH_ERASER_BORDER_WIDTH,
} from 'features/controlLayers/konva/constants';
import { alignCoordForTool, getPrefixedId } from 'features/controlLayers/konva/util';
import { isDrawableEntity } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { Logger } from 'roarr';
@ -156,10 +157,7 @@ export class CanvasTool {
const tool = toolState.selected;
const isDrawableEntity =
selectedEntity?.state.type === 'regional_guidance' ||
selectedEntity?.state.type === 'layer' ||
selectedEntity?.state.type === 'inpaint_mask';
const isDrawable = selectedEntity && isDrawableEntity(selectedEntity.state);
// Update the stage's pointer style
if (tool === 'view') {
@ -168,7 +166,7 @@ export class CanvasTool {
} else if (renderedEntityCount === 0) {
// We have no layers, so we should not render any tool
stage.container().style.cursor = 'default';
} else if (!isDrawableEntity) {
} else if (!isDrawable) {
// Non-drawable layers don't have tools
stage.container().style.cursor = 'not-allowed';
} else if (tool === 'move' || Boolean(this.manager.stateApi.$transformingEntity.get())) {
@ -186,7 +184,7 @@ export class CanvasTool {
stage.draggable(tool === 'view');
if (!cursorPos || renderedEntityCount === 0 || !isDrawableEntity) {
if (!cursorPos || renderedEntityCount === 0 || !isDrawable) {
// We can bail early if the mouse isn't over the stage or there are no layers
this.konva.group.visible(false);
} else {

View File

@ -6,8 +6,9 @@ import {
offsetCoord,
} from 'features/controlLayers/konva/util';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
Coordinate,
@ -84,7 +85,7 @@ const getLastPointOfLine = (points: number[]): Coordinate | null => {
};
const getLastPointOfLastLineOfEntity = (
entity: CanvasLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState,
entity: CanvasRasterLayerState | CanvasControlLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState,
tool: Tool
): Coordinate | null => {
const lastObject = entity.objects[entity.objects.length - 1];
@ -138,7 +139,9 @@ export const setStageEventHandlers = (manager: CanvasManager): (() => void) => {
return e.evt.buttons === 1;
}
function getClip(entity: CanvasRegionalGuidanceState | CanvasLayerState | CanvasInpaintMaskState) {
function getClip(
entity: CanvasRegionalGuidanceState | CanvasControlLayerState | CanvasRasterLayerState | CanvasInpaintMaskState
) {
const settings = getSettings();
const bboxRect = getBbox().rect;

View File

@ -12,12 +12,12 @@ import { pick } from 'lodash-es';
export const bboxReducers = {
bboxScaledSizeChanged: (state, action: PayloadAction<Partial<Dimensions>>) => {
state.layers.imageCache = null;
state.rasterLayers.imageCache = null;
state.bbox.scaledSize = { ...state.bbox.scaledSize, ...action.payload };
},
bboxScaleMethodChanged: (state, action: PayloadAction<BoundingBoxScaleMethod>) => {
state.bbox.scaleMethod = action.payload;
state.layers.imageCache = null;
state.rasterLayers.imageCache = null;
if (action.payload === 'auto') {
const optimalDimension = getOptimalDimension(state.params.model);
@ -27,7 +27,7 @@ export const bboxReducers = {
},
bboxChanged: (state, action: PayloadAction<IRect>) => {
state.bbox.rect = action.payload;
state.layers.imageCache = null;
state.rasterLayers.imageCache = null;
if (state.bbox.scaleMethod === 'auto') {
const optimalDimension = getOptimalDimension(state.params.model);

View File

@ -5,11 +5,12 @@ import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/uti
import { deepClone } from 'common/util/deepClone';
import { bboxReducers } from 'features/controlLayers/store/bboxReducers';
import { compositingReducers } from 'features/controlLayers/store/compositingReducers';
import { controlLayersReducers } from 'features/controlLayers/store/controlLayersReducers';
import { inpaintMaskReducers } from 'features/controlLayers/store/inpaintMaskReducers';
import { ipAdaptersReducers } from 'features/controlLayers/store/ipAdaptersReducers';
import { layersReducers } from 'features/controlLayers/store/layersReducers';
import { lorasReducers } from 'features/controlLayers/store/lorasReducers';
import { paramsReducers } from 'features/controlLayers/store/paramsReducers';
import { rasterLayersReducers } from 'features/controlLayers/store/rasterLayersReducers';
import { regionsReducers } from 'features/controlLayers/store/regionsReducers';
import { sessionReducers } from 'features/controlLayers/store/sessionReducers';
import { settingsReducers } from 'features/controlLayers/store/settingsReducers';
@ -23,9 +24,10 @@ import type { InvocationDenoiseProgressEvent } from 'services/events/types';
import { assert } from 'tsafe';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasInpaintMaskState,
CanvasLayerState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasV2State,
Coordinate,
@ -38,12 +40,13 @@ import type {
FilterConfig,
StageAttrs,
} from './types';
import { IMAGE_FILTERS, RGBA_RED } from './types';
import { IMAGE_FILTERS, isDrawableEntity, RGBA_RED } from './types';
const initialState: CanvasV2State = {
_version: 3,
selectedEntityIdentifier: null,
layers: { entities: [], compositeRasterizationCache: [] },
rasterLayers: { entities: [], compositeRasterizationCache: [] },
controlLayers: { entities: [] },
ipAdapters: { entities: [] },
regions: { entities: [] },
loras: [],
@ -143,27 +146,21 @@ const initialState: CanvasV2State = {
export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) {
switch (type) {
case 'layer':
return state.layers.entities.find((layer) => layer.id === id);
case 'raster_layer':
return state.rasterLayers.entities.find((layer) => layer.id === id);
case 'control_layer':
return state.controlLayers.entities.find((layer) => layer.id === id);
case 'inpaint_mask':
return state.inpaintMask;
case 'regional_guidance':
return state.regions.entities.find((rg) => rg.id === id);
case 'ip_adapter':
return state.ipAdapters.entities.find((ip) => ip.id === id);
default:
return;
}
}
const invalidateCompositeRasterizationCache = (entity: CanvasLayerState, state: CanvasV2State) => {
if (entity.controlAdapter === null) {
state.layers.compositeRasterizationCache = [];
}
};
const invalidateRasterizationCaches = (
entity: CanvasLayerState | CanvasInpaintMaskState | CanvasRegionalGuidanceState,
entity: CanvasRasterLayerState | CanvasControlLayerState | CanvasInpaintMaskState | CanvasRegionalGuidanceState,
state: CanvasV2State
) => {
// TODO(psyche): We can be more efficient and only invalidate caches when the entity's changes intersect with the
@ -176,8 +173,8 @@ const invalidateRasterizationCaches = (
// layer's image data will contribute to the composite layer's image data.
// If the layer is used as a control layer, it will not contribute to the composite layer, so we do not need to reset
// its cache.
if (entity.type === 'layer') {
invalidateCompositeRasterizationCache(entity, state);
if (entity.type === 'raster_layer') {
state.rasterLayers.compositeRasterizationCache = [];
}
};
@ -185,7 +182,8 @@ export const canvasV2Slice = createSlice({
name: 'canvasV2',
initialState,
reducers: {
...layersReducers,
...rasterLayersReducers,
...controlLayersReducers,
...ipAdaptersReducers,
...regionsReducers,
...lorasReducers,
@ -205,7 +203,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
} else if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
} else if (isDrawableEntity(entity)) {
entity.isEnabled = true;
entity.objects = [];
entity.position = { x: 0, y: 0 };
@ -229,7 +227,7 @@ export const canvasV2Slice = createSlice({
return;
}
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
if (isDrawableEntity(entity)) {
entity.position = position;
// When an entity is moved, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state);
@ -242,7 +240,7 @@ export const canvasV2Slice = createSlice({
return;
}
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
if (isDrawableEntity(entity)) {
entity.objects = [imageObject];
entity.position = { x: rect.x, y: rect.y };
// Remove the cache for the given rect. This should never happen, because we should never rasterize the same
@ -258,7 +256,7 @@ export const canvasV2Slice = createSlice({
return;
}
if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
if (isDrawableEntity(entity)) {
entity.objects.push(brushLine);
// When adding a brush line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state);
@ -269,7 +267,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
} else if (entity.type === 'layer' || entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
} else if (isDrawableEntity(entity)) {
entity.objects.push(eraserLine);
// When adding an eraser line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state);
@ -282,7 +280,7 @@ export const canvasV2Slice = createSlice({
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
} else if (entity.type === 'layer') {
} else if (isDrawableEntity(entity)) {
entity.objects.push(rect);
// When adding an eraser line, we need to invalidate the rasterization caches.
invalidateRasterizationCaches(entity, state);
@ -292,18 +290,37 @@ export const canvasV2Slice = createSlice({
},
entityDeleted: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (entity?.type === 'layer') {
// When a layer is deleted, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state);
}
if (entityIdentifier.type === 'layer') {
state.layers.entities = state.layers.entities.filter((layer) => layer.id !== entityIdentifier.id);
let selectedEntityIdentifier: CanvasEntityIdentifier = { type: state.inpaintMask.type, id: state.inpaintMask.id };
if (entityIdentifier.type === 'raster_layer') {
// When deleting a raster layer, we need to invalidate the composite rasterization cache.
const index = state.rasterLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== entityIdentifier.id);
state.rasterLayers.compositeRasterizationCache = [];
const nextRasterLayer = state.rasterLayers.entities[index];
if (nextRasterLayer) {
selectedEntityIdentifier = { type: nextRasterLayer.type, id: nextRasterLayer.id };
}
} else if (entityIdentifier.type === 'control_layer') {
const index = state.controlLayers.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.controlLayers.entities = state.controlLayers.entities.filter((rg) => rg.id !== entityIdentifier.id);
const nextControlLayer = state.controlLayers.entities[index];
if (nextControlLayer) {
selectedEntityIdentifier = { type: nextControlLayer.type, id: nextControlLayer.id };
}
} else if (entityIdentifier.type === 'regional_guidance') {
const index = state.regions.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.regions.entities = state.regions.entities.filter((rg) => rg.id !== entityIdentifier.id);
const region = state.regions.entities[index];
if (region) {
selectedEntityIdentifier = { type: region.type, id: region.id };
}
} else {
assert(false, 'Not implemented');
}
state.selectedEntityIdentifier = selectedEntityIdentifier;
},
entityArrangedForwardOne: (state, action: PayloadAction<EntityIdentifierPayload>) => {
const { entityIdentifier } = action.payload;
@ -311,10 +328,12 @@ export const canvasV2Slice = createSlice({
if (!entity) {
return;
}
if (entity.type === 'layer') {
moveOneToEnd(state.layers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state);
if (entity.type === 'raster_layer') {
moveOneToEnd(state.rasterLayers.entities, entity);
// When arranging a raster layer, we need to invalidate the composite rasterization cache.
state.rasterLayers.compositeRasterizationCache = [];
} else if (entity.type === 'control_layer') {
moveOneToEnd(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') {
moveOneToEnd(state.regions.entities, entity);
}
@ -325,10 +344,12 @@ export const canvasV2Slice = createSlice({
if (!entity) {
return;
}
if (entity.type === 'layer') {
moveToEnd(state.layers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state);
if (entity.type === 'raster_layer') {
moveToEnd(state.rasterLayers.entities, entity);
// When arranging a raster layer, we need to invalidate the composite rasterization cache.
state.rasterLayers.compositeRasterizationCache = [];
} else if (entity.type === 'control_layer') {
moveToEnd(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') {
moveToEnd(state.regions.entities, entity);
}
@ -339,10 +360,11 @@ export const canvasV2Slice = createSlice({
if (!entity) {
return;
}
if (entity.type === 'layer') {
moveOneToStart(state.layers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state);
if (entity.type === 'raster_layer') {
moveOneToStart(state.rasterLayers.entities, entity);
// When arranging a raster layer, we need to invalidate the composite rasterization cache.
} else if (entity.type === 'control_layer') {
moveOneToStart(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') {
moveOneToStart(state.regions.entities, entity);
}
@ -353,18 +375,19 @@ export const canvasV2Slice = createSlice({
if (!entity) {
return;
}
if (entity.type === 'layer') {
moveToStart(state.layers.entities, entity);
// When arranging an entity, we may need to invalidate the composite rasterization cache.
invalidateCompositeRasterizationCache(entity, state);
if (entity.type === 'raster_layer') {
moveToStart(state.rasterLayers.entities, entity);
state.rasterLayers.compositeRasterizationCache = [];
} else if (entity.type === 'control_layer') {
moveToStart(state.controlLayers.entities, entity);
} else if (entity.type === 'regional_guidance') {
moveToStart(state.regions.entities, entity);
}
},
allEntitiesDeleted: (state) => {
state.regions.entities = [];
state.layers.entities = [];
state.layers.compositeRasterizationCache = [];
state.rasterLayers.entities = [];
state.rasterLayers.compositeRasterizationCache = [];
state.ipAdapters.entities = [];
},
filterSelected: (state, action: PayloadAction<{ type: FilterConfig['type'] }>) => {
@ -377,8 +400,8 @@ export const canvasV2Slice = createSlice({
// Invalidate the rasterization caches for all entities.
// Layers & composite layer
state.layers.compositeRasterizationCache = [];
for (const layer of state.layers.entities) {
state.rasterLayers.compositeRasterizationCache = [];
for (const layer of state.rasterLayers.entities) {
layer.rasterizationCache = [];
}
@ -399,7 +422,8 @@ export const canvasV2Slice = createSlice({
state.bbox.scaledSize = getScaledBoundingBoxDimensions(size, optimalDimension);
state.ipAdapters = deepClone(initialState.ipAdapters);
state.layers = deepClone(initialState.layers);
state.rasterLayers = deepClone(initialState.rasterLayers);
state.controlLayers = deepClone(initialState.controlLayers);
state.regions = deepClone(initialState.regions);
state.selectedEntityIdentifier = deepClone(initialState.selectedEntityIdentifier);
state.session = deepClone(initialState.session);
@ -445,16 +469,21 @@ export const {
bboxAspectRatioIdChanged,
bboxDimensionsSwapped,
bboxSizeOptimized,
// layers
layerAdded,
layerRecalled,
layerAllDeleted,
layerUsedAsControlChanged,
layerControlAdapterModelChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterWeightChanged,
layerControlAdapterBeginEndStepPctChanged,
layerCompositeRasterized,
// Raster layers
rasterLayerAdded,
rasterLayerRecalled,
rasterLayerAllDeleted,
rasterLayerConvertedToControlLayer,
rasterLayerCompositeRasterized,
// Control layers
controlLayerAdded,
controlLayerRecalled,
controlLayerAllDeleted,
controlLayerConvertedToRasterLayer,
controlLayerModelChanged,
controlLayerControlModeChanged,
controlLayerWeightChanged,
controlLayerBeginEndStepPctChanged,
// IP Adapters
ipaAdded,
ipaRecalled,

View File

@ -0,0 +1,153 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
CanvasControlLayerState,
CanvasRasterLayerState,
CanvasV2State,
ControlModeV2,
ControlNetConfig,
T2IAdapterConfig,
} from './types';
import { initialControlNetV2 } from './types';
export const selectControlLayer = (state: CanvasV2State, id: string) =>
state.controlLayers.entities.find((layer) => layer.id === id);
export const selectControlLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectControlLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const controlLayersReducers = {
controlLayerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasControlLayerState = {
id,
type: 'control_layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
controlAdapter: deepClone(initialControlNetV2),
};
merge(layer, overrides);
state.controlLayers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'control_layer', id };
}
},
prepare: (payload: { overrides?: Partial<CanvasControlLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('control_layer') },
}),
},
controlLayerRecalled: (state, action: PayloadAction<{ data: CanvasControlLayerState }>) => {
const { data } = action.payload;
state.controlLayers.entities.push(data);
state.selectedEntityIdentifier = { type: 'control_layer', id: data.id };
},
controlLayerAllDeleted: (state) => {
state.controlLayers.entities = [];
},
controlLayerConvertedToRasterLayer: {
reducer: (state, action: PayloadAction<{ id: string; newId: string }>) => {
const { id, newId } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer) {
return;
}
// Convert the raster layer to control layer
const rasterLayerState: CanvasRasterLayerState = {
...omit(deepClone(layer), ['type', 'controlAdapter']),
id: newId,
type: 'raster_layer',
};
// Remove the control layer
state.controlLayers.entities = state.controlLayers.entities.filter((layer) => layer.id !== id);
// Add the new raster layer
state.rasterLayers.entities.push(rasterLayerState);
// The composite layer's image data will change when the control layer is converted to raster layer.
state.rasterLayers.compositeRasterizationCache = [];
state.selectedEntityIdentifier = { type: rasterLayerState.type, id: rasterLayerState.id };
},
prepare: (payload: { id: string }) => ({
payload: { ...payload, newId: getPrefixedId('raster_layer') },
}),
},
controlLayerModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
controlLayerControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
controlLayerWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
controlLayerBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectControlLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,142 +0,0 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { isEqual, merge } from 'lodash-es';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { CanvasLayerState, CanvasV2State, ControlModeV2, ControlNetConfig, Rect, T2IAdapterConfig } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.entities.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const layersReducers = {
layerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasLayerState = {
id,
type: 'layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
controlAdapter: null,
};
merge(layer, overrides);
state.layers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'layer', id };
}
if (layer.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.layers.compositeRasterizationCache = [];
}
},
prepare: (payload: { overrides?: Partial<CanvasLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('layer') },
}),
},
layerRecalled: (state, action: PayloadAction<{ data: CanvasLayerState }>) => {
const { data } = action.payload;
state.layers.entities.push(data);
state.selectedEntityIdentifier = { type: 'layer', id: data.id };
if (data.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.layers.compositeRasterizationCache = [];
}
},
layerAllDeleted: (state) => {
state.layers.entities = [];
state.layers.compositeRasterizationCache = [];
},
layerCompositeRasterized: (state, action: PayloadAction<{ imageName: string; rect: Rect }>) => {
state.layers.compositeRasterizationCache = state.layers.compositeRasterizationCache.filter(
(cache) => !isEqual(cache.rect, action.payload.rect)
);
state.layers.compositeRasterizationCache.push(action.payload);
},
layerUsedAsControlChanged: (
state,
action: PayloadAction<{ id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig | null }>
) => {
const { id, controlAdapter } = action.payload;
const layer = selectLayer(state, id);
if (!layer) {
return;
}
layer.controlAdapter = controlAdapter;
// The composite layer's image data will change when the layer is used as control (or not). Invalidate the cache.
state.layers.compositeRasterizationCache = [];
},
layerControlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
layerControlAdapterControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
layerControlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
layerControlAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -0,0 +1,108 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { isEqual, merge } from 'lodash-es';
import { assert } from 'tsafe';
import type {
CanvasControlLayerState,
CanvasRasterLayerState,
CanvasV2State,
ControlNetConfig,
Rect,
T2IAdapterConfig,
} from './types';
export const selectRasterLayer = (state: CanvasV2State, id: string) =>
state.rasterLayers.entities.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
const layer = selectRasterLayer(state, id);
assert(layer, `Layer with id ${id} not found`);
return layer;
};
export const rasterLayersReducers = {
rasterLayerAdded: {
reducer: (
state,
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }>
) => {
const { id, overrides, isSelected } = action.payload;
const layer: CanvasRasterLayerState = {
id,
type: 'raster_layer',
isEnabled: true,
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
rasterizationCache: [],
};
merge(layer, overrides);
state.rasterLayers.entities.push(layer);
if (isSelected) {
state.selectedEntityIdentifier = { type: 'raster_layer', id };
}
if (layer.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.rasterLayers.compositeRasterizationCache = [];
}
},
prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({
payload: { ...payload, id: getPrefixedId('raster_layer') },
}),
},
rasterLayerRecalled: (state, action: PayloadAction<{ data: CanvasRasterLayerState }>) => {
const { data } = action.payload;
state.rasterLayers.entities.push(data);
state.selectedEntityIdentifier = { type: 'raster_layer', id: data.id };
if (data.objects.length > 0) {
// This new layer will change the composite layer's image data. Invalidate the cache.
state.rasterLayers.compositeRasterizationCache = [];
}
},
rasterLayerAllDeleted: (state) => {
state.rasterLayers.entities = [];
state.rasterLayers.compositeRasterizationCache = [];
},
rasterLayerCompositeRasterized: (state, action: PayloadAction<{ imageName: string; rect: Rect }>) => {
state.rasterLayers.compositeRasterizationCache = state.rasterLayers.compositeRasterizationCache.filter(
(cache) => !isEqual(cache.rect, action.payload.rect)
);
state.rasterLayers.compositeRasterizationCache.push(action.payload);
},
rasterLayerConvertedToControlLayer: {
reducer: (
state,
action: PayloadAction<{ id: string; newId: string; controlAdapter: ControlNetConfig | T2IAdapterConfig }>
) => {
const { id, newId, controlAdapter } = action.payload;
const layer = selectRasterLayer(state, id);
if (!layer) {
return;
}
// Convert the raster layer to control layer
const controlLayerState: CanvasControlLayerState = {
...deepClone(layer),
id: newId,
type: 'control_layer',
controlAdapter,
};
// Remove the raster layer
state.rasterLayers.entities = state.rasterLayers.entities.filter((layer) => layer.id !== id);
// Add the converted control layer
state.controlLayers.entities.push(controlLayerState);
// The composite layer's image data will change when the raster layer is converted to control layer.
state.rasterLayers.compositeRasterizationCache = [];
state.selectedEntityIdentifier = { type: controlLayerState.type, id: controlLayerState.id };
},
prepare: (payload: { id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig }) => ({
payload: { ...payload, newId: getPrefixedId('control_layer') },
}),
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -7,7 +7,7 @@ export const selectEntityCount = createSelector(selectCanvasV2Slice, (canvasV2)
canvasV2.regions.entities.length +
// canvasV2.controlAdapters.entities.length +
canvasV2.ipAdapters.entities.length +
canvasV2.layers.entities.length
canvasV2.rasterLayers.entities.length
);
});

View File

@ -728,23 +728,22 @@ const zT2IAdapterConfig = z.object({
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
export const zCanvasLayerState = z.object({
export const zCanvasRasterLayerState = z.object({
id: zId,
type: z.literal('layer'),
type: z.literal('raster_layer'),
isEnabled: z.boolean(),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
rasterizationCache: z.array(zImageCache),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
export type CanvasLayerStateWithValidControlNet = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<ControlNetConfig, 'model'> & { model: ControlNetModelConfig };
};
export type CanvasLayerStateWithValidT2IAdapter = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<T2IAdapterConfig, 'model'> & { model: T2IAdapterModelConfig };
};
export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
export const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
type: z.literal('control_layer'),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
});
export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;
export const initialControlNetV2: ControlNetConfig = {
type: 'controlnet',
@ -808,8 +807,8 @@ export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMetho
zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntityState =
| CanvasLayerState
| CanvasControlAdapterState
| CanvasRasterLayerState
| CanvasControlLayerState
| CanvasRegionalGuidanceState
| CanvasInpaintMaskState
| CanvasIPAdapterState;
@ -832,7 +831,8 @@ export type CanvasV2State = {
_version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: CanvasInpaintMaskState;
layers: { entities: CanvasLayerState[]; compositeRasterizationCache: ImageCache[] };
rasterLayers: { entities: CanvasRasterLayerState[]; compositeRasterizationCache: ImageCache[] };
controlLayers: { entities: CanvasControlLayerState[] };
ipAdapters: { entities: CanvasIPAdapterState[] };
regions: { entities: CanvasRegionalGuidanceState[] };
loras: LoRA[];
@ -962,10 +962,19 @@ export type RemoveIndexString<T> = {
export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';
export function isDrawableEntityType(entityType: CanvasEntityState['type']) {
return (
entityType === 'raster_layer' ||
entityType === 'control_layer' ||
entityType === 'regional_guidance' ||
entityType === 'inpaint_mask'
);
}
export function isDrawableEntity(
entity: CanvasEntityState
): entity is CanvasLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState {
return entity.type === 'layer' || entity.type === 'regional_guidance' || entity.type === 'inpaint_mask';
): entity is CanvasRasterLayerState | CanvasControlLayerState | CanvasRegionalGuidanceState | CanvasInpaintMaskState {
return isDrawableEntityType(entity.type);
}
export function isDrawableEntityAdapter(
@ -973,9 +982,3 @@ export function isDrawableEntityAdapter(
): adapter is CanvasLayerAdapter | CanvasMaskAdapter {
return adapter instanceof CanvasLayerAdapter || adapter instanceof CanvasMaskAdapter;
}
export function isDrawableEntityType(
entityType: CanvasEntityState['type']
): entityType is 'layer' | 'regional_guidance' | 'inpaint_mask' {
return entityType === 'layer' || entityType === 'regional_guidance' || entityType === 'inpaint_mask';
}

View File

@ -11,7 +11,7 @@ import { some } from 'lodash-es';
import type { ImageUsage } from './types';
export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_name: string) => {
const isLayerImage = canvasV2.layers.entities.some((layer) =>
const isLayerImage = canvasV2.rasterLayers.entities.some((layer) =>
layer.objects.some((obj) => obj.type === 'image' && obj.image.image_name === image_name)
);

View File

@ -1,4 +1,4 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
@ -9,7 +9,7 @@ type Props = {
};
export const MetadataLayers = ({ metadata }: Props) => {
const [layers, setLayers] = useState<CanvasLayerState[]>([]);
const [layers, setLayers] = useState<CanvasRasterLayerState[]>([]);
useEffect(() => {
const parse = async () => {
@ -40,8 +40,8 @@ const MetadataViewLayer = ({
handlers,
}: {
label: string;
layer: CanvasLayerState;
handlers: MetadataHandlers<CanvasLayerState[], CanvasLayerState>;
layer: CanvasRasterLayerState;
handlers: MetadataHandlers<CanvasRasterLayerState[], CanvasRasterLayerState>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {

View File

@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/canvasV2Slice';
import type { CanvasLayerState, LoRA } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import type {
AnyControlAdapterConfigMetadata,
BuildMetadataHandlers,
@ -48,7 +48,7 @@ const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfig
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};
const renderLayerValue: MetadataRenderValueFunc<CanvasLayerState> = async (layer) => {
const renderLayerValue: MetadataRenderValueFunc<CanvasRasterLayerState> = async (layer) => {
if (layer.type === 'initial_image_layer') {
let rendered = t('controlLayers.globalInitialImageLayer');
if (layer.image) {
@ -88,7 +88,7 @@ const renderLayerValue: MetadataRenderValueFunc<CanvasLayerState> = async (layer
}
assert(false, 'Unknown layer type');
};
const renderLayersValue: MetadataRenderValueFunc<CanvasLayerState[]> = async (layers) => {
const renderLayersValue: MetadataRenderValueFunc<CanvasRasterLayerState[]> = async (layers) => {
return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`;
};

View File

@ -1,6 +1,6 @@
import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/controlLayers/konva/naming';
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasLayerState, LoRA } from 'features/controlLayers/store/types';
import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import {
IMAGE_FILTERS,
imageDTOToImageWithDims,
@ -8,7 +8,7 @@ import {
initialIPAdapterV2,
initialT2IAdapterV2,
isFilterType,
zCanvasLayerState,
zCanvasRasterLayerState,
} from 'features/controlLayers/store/types';
import type {
ControlNetConfigMetadata,
@ -424,22 +424,22 @@ const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (
};
//#region Control Layers
const parseLayer: MetadataParseFunc<CanvasLayerState> = async (metadataItem) => zCanvasLayerState.parseAsync(metadataItem);
const parseLayer: MetadataParseFunc<CanvasRasterLayerState> = async (metadataItem) => zCanvasRasterLayerState.parseAsync(metadataItem);
const parseLayers: MetadataParseFunc<CanvasLayerState[]> = async (metadata) => {
const parseLayers: MetadataParseFunc<CanvasRasterLayerState[]> = async (metadata) => {
// We need to support recalling pre-Control Layers metadata into Control Layers. A separate set of parsers handles
// taking pre-CL metadata and parsing it into layers. It doesn't always map 1-to-1, so this is best-effort. For
// example, CL Control Adapters don't support resize mode, so we simply omit that property.
try {
const layers: CanvasLayerState[] = [];
const layers: CanvasRasterLayerState[] = [];
try {
const control_layers = await getProperty(metadata, 'control_layers');
const controlLayersRaw = await getProperty(control_layers, 'layers', isArray);
const controlLayersParseResults = await Promise.allSettled(controlLayersRaw.map(parseLayer));
const controlLayers = controlLayersParseResults
.filter((result): result is PromiseFulfilledResult<CanvasLayerState> => result.status === 'fulfilled')
.filter((result): result is PromiseFulfilledResult<CanvasRasterLayerState> => result.status === 'fulfilled')
.map((result) => result.value);
layers.push(...controlLayers);
} catch {
@ -498,16 +498,16 @@ const parseLayers: MetadataParseFunc<CanvasLayerState[]> = async (metadata) => {
}
};
const parseInitialImageToInitialImageLayer: MetadataParseFunc<CanvasLayerState> = async (metadata) => {
const parseInitialImageToInitialImageLayer: MetadataParseFunc<CanvasRasterLayerState> = async (metadata) => {
// TODO(psyche): recall denoise strength
// const denoisingStrength = await getProperty(metadata, 'strength', isParameterStrength);
const imageName = await getProperty(metadata, 'init_image', isString);
const imageDTO = await getImageDTO(imageName);
assert(imageDTO, 'ImageDTO is null');
const id = getLayerId(uuidv4());
const layer: CanvasLayerState = {
const layer: CanvasRasterLayerState = {
id,
type: 'layer',
type: 'raster_layer',
bbox: null,
bboxNeedsUpdate: true,
x: 0,

View File

@ -15,8 +15,8 @@ import {
bboxWidthChanged,
// caRecalled,
ipaRecalled,
layerAllDeleted,
layerRecalled,
rasterLayerAllDeleted,
rasterLayerRecalled,
loraAllDeleted,
loraRecalled,
negativePrompt2Changed,
@ -42,7 +42,7 @@ import {
import type {
CanvasControlAdapterState,
CanvasIPAdapterState,
CanvasLayerState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
LoRA,
} from 'features/controlLayers/store/types';
@ -328,7 +328,7 @@ const recallRG: MetadataRecallFunc<CanvasRegionalGuidanceState> = async (rg) =>
};
//#region Control Layers
const recallLayer: MetadataRecallFunc<CanvasLayerState> = async (layer) => {
const recallLayer: MetadataRecallFunc<CanvasRasterLayerState> = async (layer) => {
const { dispatch } = getStore();
const clone = deepClone(layer);
const invalidObjects: string[] = [];
@ -355,13 +355,13 @@ const recallLayer: MetadataRecallFunc<CanvasLayerState> = async (layer) => {
}
}
clone.id = getRGId(uuidv4());
dispatch(layerRecalled({ data: clone }));
dispatch(rasterLayerRecalled({ data: clone }));
return;
};
const recallLayers: MetadataRecallFunc<CanvasLayerState[]> = (layers) => {
const recallLayers: MetadataRecallFunc<CanvasRasterLayerState[]> = (layers) => {
const { dispatch } = getStore();
dispatch(layerAllDeleted());
dispatch(rasterLayerAllDeleted());
for (const l of layers) {
recallLayer(l);
}

View File

@ -1,5 +1,5 @@
import { getStore } from 'app/store/nanostores/store';
import type { CanvasLayerState, LoRA } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState, LoRA } from 'features/controlLayers/store/types';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
@ -109,7 +109,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
return new Promise((resolve) => resolve(validatedIPAdapters));
};
const validateLayer: MetadataValidateFunc<CanvasLayerState> = async (layer) => {
const validateLayer: MetadataValidateFunc<CanvasRasterLayerState> = async (layer) => {
if (layer.type === 'control_adapter_layer') {
const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model');
@ -131,8 +131,8 @@ const validateLayer: MetadataValidateFunc<CanvasLayerState> = async (layer) => {
return layer;
};
const validateLayers: MetadataValidateFunc<CanvasLayerState[]> = async (layers) => {
const validatedLayers: CanvasLayerState[] = [];
const validateLayers: MetadataValidateFunc<CanvasRasterLayerState[]> = async (layers) => {
const validatedLayers: CanvasRasterLayerState[] = [];
for (const l of layers) {
try {
const validated = await validateLayer(l);

View File

@ -1,15 +1,10 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasLayerState,
CanvasLayerStateWithValidControlNet,
CanvasLayerStateWithValidT2IAdapter,
CanvasControlLayerState,
ControlNetConfig,
FilterConfig,
ImageWithDims,
Rect,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
@ -17,18 +12,18 @@ import { assert } from 'tsafe';
export const addControlAdapters = async (
manager: CanvasManager,
layers: CanvasLayerState[],
layers: CanvasControlLayerState[],
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>,
base: BaseModelType
): Promise<(CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter)[]> => {
const layersWithValidControlAdapters = layers
): Promise<CanvasControlLayerState[]> => {
const validControlLayers = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => doesLayerHaveValidControlAdapter(layer, base));
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base));
for (const layer of layersWithValidControlAdapters) {
const adapter = manager.layers.get(layer.id);
for (const layer of validControlLayers) {
const adapter = manager.controlLayerAdapters.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize(bbox);
if (layer.controlAdapter.type === 'controlnet') {
@ -37,7 +32,7 @@ export const addControlAdapters = async (
await addT2IAdapterToGraph(g, layer, imageDTO, denoise);
}
}
return layersWithValidControlAdapters;
return validControlLayers;
};
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
@ -59,12 +54,14 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
const addControlNetToGraph = (
g: Graph,
layer: CanvasLayerStateWithValidControlNet,
layer: CanvasControlLayerState,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'>
) => {
const { id, controlAdapter } = layer;
assert(controlAdapter.type === 'controlnet');
const { beginEndStepPct, model, weight, controlMode } = controlAdapter;
assert(model !== null);
const { image_name } = imageDTO;
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
@ -103,12 +100,14 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
const addT2IAdapterToGraph = (
g: Graph,
layer: CanvasLayerStateWithValidT2IAdapter,
layer: CanvasControlLayerState,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'>
) => {
const { id, controlAdapter } = layer;
assert(controlAdapter.type === 't2i_adapter');
const { beginEndStepPct, model, weight } = controlAdapter;
assert(model !== null);
const { image_name } = imageDTO;
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
@ -127,25 +126,6 @@ const addT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: FilterConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.image_name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.image_name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model
const hasModel = Boolean(controlAdapter.model);
@ -153,22 +133,3 @@ const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConf
const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase;
};
const doesLayerHaveValidControlAdapter = (
layer: CanvasLayerState,
base: BaseModelType
): layer is CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter => {
if (!layer.controlAdapter) {
// Must have a control adapter
return false;
}
if (!layer.controlAdapter.model) {
// Control adapter must have a model selected
return false;
}
if (layer.controlAdapter.model.base !== base) {
// Selected model must match current base model
return false;
}
return true;
};

View File

@ -22,7 +22,7 @@ export const addInpaint = async (
denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const maskImage = await manager.inpaintMask.renderer.rasterize(bbox.rect);
const maskImage = await manager.inpaintMaskAdapter.renderer.rasterize(bbox.rect);
if (!isEqual(scaledSize, originalSize)) {
// Scale before processing requires some resizing

View File

@ -1,6 +1,6 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
export const isValidLayerWithoutControlAdapter = (layer: CanvasLayerState) => {
export const isValidLayerWithoutControlAdapter = (layer: CanvasRasterLayerState) => {
return (
layer.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers

View File

@ -23,7 +23,7 @@ export const addOutpaint = async (
denoise.denoising_start = denoising_start;
const initialImage = await manager.getCompositeLayerImageDTO(bbox.rect);
const maskImage = await manager.inpaintMask.renderer.rasterize(bbox.rect);
const maskImage = await manager.inpaintMaskAdapter.renderer.rasterize(bbox.rect);
const infill = getInfill(g, compositing);
if (!isEqual(scaledSize, originalSize)) {

View File

@ -43,7 +43,7 @@ export const addRegions = async (
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
for (const region of validRegions) {
const adapter = manager.regions.get(region.id);
const adapter = manager.regionalGuidanceAdapters.get(region.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.rasterize(bbox);

View File

@ -215,7 +215,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.layers.entities,
state.canvasV2.rasterLayers.entities,
g,
state.canvasV2.bbox.rect,
denoise,

View File

@ -219,7 +219,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.layers.entities,
state.canvasV2.rasterLayers.entities,
g,
state.canvasV2.bbox.rect,
denoise,