refactor(ui): create classes to abstract mgmt of konva nodes

This commit is contained in:
psychedelicious 2024-06-19 20:30:49 +10:00
parent 995c26751e
commit d965df8ca9
28 changed files with 482 additions and 433 deletions

@ -67,11 +67,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// We should only process if the processor settings or image have changed // We should only process if the processor settings or image have changed
const originalCA = selectCA(originalState.canvasV2, id); const originalCA = selectCA(originalState.canvasV2, id);
const originalImage = originalCA?.image; const originalImage = originalCA?.imageObject;
const originalConfig = originalCA?.processorConfig; const originalConfig = originalCA?.processorConfig;
const image = ca.image; const image = ca.imageObject;
const processedImage = ca.processedImage; const processedImage = ca.processedImageObject;
const config = ca.processorConfig; const config = ca.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) { if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
@ -95,7 +95,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
} }
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never); const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image.image, config as never);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
prepend: true, prepend: true,
batch: { batch: {

@ -49,7 +49,7 @@ const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, ima
}; };
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => { const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.ipAdapters.forEach(({ id, image }) => { state.canvasV2.ipAdapters.forEach(({ id, imageObject: image }) => {
if (image?.name === imageDTO.image_name) { if (image?.name === imageDTO.image_name) {
dispatch(ipaImageChanged({ id, imageDTO: null })); dispatch(ipaImageChanged({ id, imageDTO: null }));
} }

@ -178,7 +178,7 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
} }
// Must have an image // Must have an image
if (!ipa.image) { if (!ipa.imageObject) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
} }
@ -214,7 +214,7 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
} }
// Must have an image // Must have an image
if (!ipAdapter.image) { if (!ipAdapter.imageObject) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
} }
}); });

@ -47,10 +47,10 @@ export const CAImagePreview = memo(
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery( const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
controlAdapter.image?.name ?? skipToken controlAdapter.imageObject?.image.name ?? skipToken
); );
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery( const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
controlAdapter.processedImage?.name ?? skipToken controlAdapter.processedImageObject?.image.name ?? skipToken
); );
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation(); const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();

@ -95,7 +95,7 @@ export const IPASettings = memo(({ id }: Props) => {
</Flex> </Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1"> <Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAImagePreview <IPAImagePreview
image={ipAdapter.image} image={ipAdapter.imageObject?.image ?? null}
onChangeImage={onChangeImage} onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id} ipAdapterId={ipAdapter.id}
droppableData={droppableData} droppableData={droppableData}

@ -123,7 +123,7 @@ export const RGIPAdapterSettings = memo(({ id, ipAdapterId, ipAdapterNumber }: P
</Flex> </Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1"> <Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAImagePreview <IPAImagePreview
image={ipAdapter.image} image={ipAdapter.imageObject?.image ?? null}
onChangeImage={onChangeImage} onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id} ipAdapterId={ipAdapter.id}
droppableData={droppableData} droppableData={droppableData}

@ -24,21 +24,27 @@ export type RectShapeEntry = {
export type ImageEntry = { export type ImageEntry = {
id: string; id: string;
type: ImageObject['type']; type: ImageObject['type'];
konvaImageGroup: Konva.Group;
konvaPlaceholderGroup: Konva.Group;
konvaPlaceholderText: Konva.Text;
konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
konvaGroup: Konva.Group; isLoading: boolean;
isError: boolean;
}; };
type Entry = BrushLineEntry | EraserLineEntry | RectShapeEntry | ImageEntry; type Entry = BrushLineEntry | EraserLineEntry | RectShapeEntry | ImageEntry;
export class EntityToKonvaMap { export class EntityToKonvaMap {
stage: Konva.Stage;
mappings: Record<string, EntityToKonvaMapping>; mappings: Record<string, EntityToKonvaMapping>;
constructor() { constructor(stage: Konva.Stage) {
this.stage = stage;
this.mappings = {}; this.mappings = {};
} }
addMapping(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group): EntityToKonvaMapping { addMapping(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group): EntityToKonvaMapping {
const mapping = new EntityToKonvaMapping(id, konvaLayer, konvaObjectGroup); const mapping = new EntityToKonvaMapping(id, konvaLayer, konvaObjectGroup, this);
this.mappings[id] = mapping; this.mappings[id] = mapping;
return mapping; return mapping;
} }
@ -66,12 +72,14 @@ export class EntityToKonvaMapping {
konvaLayer: Konva.Layer; konvaLayer: Konva.Layer;
konvaObjectGroup: Konva.Group; konvaObjectGroup: Konva.Group;
konvaNodeEntries: Record<string, Entry>; konvaNodeEntries: Record<string, Entry>;
map: EntityToKonvaMap;
constructor(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group) { constructor(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group, map: EntityToKonvaMap) {
this.id = id; this.id = id;
this.konvaLayer = konvaLayer; this.konvaLayer = konvaLayer;
this.konvaObjectGroup = konvaObjectGroup; this.konvaObjectGroup = konvaObjectGroup;
this.konvaNodeEntries = {}; this.konvaNodeEntries = {};
this.map = map;
} }
addEntry<T extends Entry>(entry: T): T { addEntry<T extends Entry>(entry: T): T {
@ -83,8 +91,8 @@ export class EntityToKonvaMapping {
return this.konvaNodeEntries[id] as T | undefined; return this.konvaNodeEntries[id] as T | undefined;
} }
getEntries(): Entry[] { getEntries<T extends Entry>(): T[] {
return Object.values(this.konvaNodeEntries); return Object.values(this.konvaNodeEntries) as T[];
} }
destroyEntry(id: string): void { destroyEntry(id: string): void {
@ -97,7 +105,7 @@ export class EntityToKonvaMapping {
} else if (entry.type === 'rect_shape') { } else if (entry.type === 'rect_shape') {
entry.konvaRect.destroy(); entry.konvaRect.destroy();
} else if (entry.type === 'image') { } else if (entry.type === 'image') {
entry.konvaGroup.destroy(); entry.konvaImageGroup.destroy();
} }
delete this.konvaNodeEntries[id]; delete this.konvaNodeEntries[id];
} }

@ -4,42 +4,40 @@
// IDs for singleton Konva layers and objects // IDs for singleton Konva layers and objects
export const PREVIEW_LAYER_ID = 'preview_layer'; export const PREVIEW_LAYER_ID = 'preview_layer';
export const PREVIEW_TOOL_GROUP_ID = 'preview_layer.tool_group'; export const PREVIEW_TOOL_GROUP_ID = `${PREVIEW_LAYER_ID}.tool_group`;
export const PREVIEW_BRUSH_GROUP_ID = 'preview_layer.brush_group'; export const PREVIEW_BRUSH_GROUP_ID = `${PREVIEW_LAYER_ID}.brush_group`;
export const PREVIEW_BRUSH_FILL_ID = 'preview_layer.brush_fill'; export const PREVIEW_BRUSH_FILL_ID = `${PREVIEW_LAYER_ID}.brush_fill`;
export const PREVIEW_BRUSH_BORDER_INNER_ID = 'preview_layer.brush_border_inner'; export const PREVIEW_BRUSH_BORDER_INNER_ID = `${PREVIEW_LAYER_ID}.brush_border_inner`;
export const PREVIEW_BRUSH_BORDER_OUTER_ID = 'preview_layer.brush_border_outer'; export const PREVIEW_BRUSH_BORDER_OUTER_ID = `${PREVIEW_LAYER_ID}.brush_border_outer`;
export const PREVIEW_RECT_ID = 'preview_layer.rect'; export const PREVIEW_RECT_ID = `${PREVIEW_LAYER_ID}.rect`;
export const PREVIEW_GENERATION_BBOX_GROUP = 'preview_layer.gen_bbox_group'; export const PREVIEW_GENERATION_BBOX_GROUP = `${PREVIEW_LAYER_ID}.gen_bbox_group`;
export const PREVIEW_GENERATION_BBOX_TRANSFORMER = 'preview_layer.gen_bbox_transformer'; export const PREVIEW_GENERATION_BBOX_TRANSFORMER = `${PREVIEW_LAYER_ID}.gen_bbox_transformer`;
export const PREVIEW_GENERATION_BBOX_DUMMY_RECT = 'preview_layer.gen_bbox_dummy_rect'; export const PREVIEW_GENERATION_BBOX_DUMMY_RECT = `${PREVIEW_LAYER_ID}.gen_bbox_dummy_rect`;
export const PREVIEW_DOCUMENT_SIZE_GROUP = 'preview_layer.doc_size_group'; export const PREVIEW_DOCUMENT_SIZE_GROUP = `${PREVIEW_LAYER_ID}.doc_size_group`;
export const PREVIEW_DOCUMENT_SIZE_STAGE_RECT = 'preview_layer.doc_size_stage_rect'; export const PREVIEW_DOCUMENT_SIZE_STAGE_RECT = `${PREVIEW_LAYER_ID}.doc_size_stage_rect`;
export const PREVIEW_DOCUMENT_SIZE_DOCUMENT_RECT = 'preview_layer.doc_size_doc_rect'; export const PREVIEW_DOCUMENT_SIZE_DOCUMENT_RECT = `${PREVIEW_LAYER_ID}.doc_size_doc_rect`;
// Names for Konva layers and objects (comparable to CSS classes) // Names for Konva layers and objects (comparable to CSS classes)
export const LAYER_BBOX_NAME = 'layer.bbox'; export const LAYER_BBOX_NAME = 'layer_bbox';
export const COMPOSITING_RECT_NAME = 'compositing-rect'; export const COMPOSITING_RECT_NAME = 'compositing_rect';
export const IMAGE_PLACEHOLDER_NAME = 'image_placeholder';
export const CA_LAYER_NAME = 'control_adapter_layer'; export const CA_LAYER_NAME = 'control_adapter';
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image'; export const CA_LAYER_OBJECT_GROUP_NAME = `${CA_LAYER_NAME}.object_group`;
export const CA_LAYER_IMAGE_NAME = `${CA_LAYER_NAME}.image`;
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
export const RG_LAYER_NAME = 'regional_guidance_layer'; export const RG_LAYER_NAME = 'regional_guidance_layer';
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group'; export const RG_LAYER_OBJECT_GROUP_NAME = `${RG_LAYER_NAME}.object_group`;
export const RG_LAYER_BRUSH_LINE_NAME = 'regional_guidance_layer.brush_line'; export const RG_LAYER_BRUSH_LINE_NAME = `${RG_LAYER_NAME}.brush_line`;
export const RG_LAYER_ERASER_LINE_NAME = 'regional_guidance_layer.eraser_line'; export const RG_LAYER_ERASER_LINE_NAME = `${RG_LAYER_NAME}.eraser_line`;
export const RG_LAYER_RECT_SHAPE_NAME = 'regional_guidance_layer.rect_shape'; export const RG_LAYER_RECT_SHAPE_NAME = `${RG_LAYER_NAME}.rect_shape`;
export const RASTER_LAYER_NAME = 'raster_layer'; export const RASTER_LAYER_NAME = 'raster_layer';
export const RASTER_LAYER_OBJECT_GROUP_NAME = 'raster_layer.object_group'; export const RASTER_LAYER_OBJECT_GROUP_NAME = `${RASTER_LAYER_NAME}.object_group`;
export const RASTER_LAYER_BRUSH_LINE_NAME = 'raster_layer.brush_line'; export const RASTER_LAYER_BRUSH_LINE_NAME = `${RASTER_LAYER_NAME}.brush_line`;
export const RASTER_LAYER_ERASER_LINE_NAME = 'raster_layer.eraser_line'; export const RASTER_LAYER_ERASER_LINE_NAME = `${RASTER_LAYER_NAME}.eraser_line`;
export const RASTER_LAYER_RECT_SHAPE_NAME = 'raster_layer.rect_shape'; export const RASTER_LAYER_RECT_SHAPE_NAME = `${RASTER_LAYER_NAME}.rect_shape`;
export const RASTER_LAYER_IMAGE_NAME = 'raster_layer.image'; export const RASTER_LAYER_IMAGE_NAME = `${RASTER_LAYER_NAME}.image`;
export const INPAINT_MASK_LAYER_NAME = 'inpaint_mask_layer'; export const INPAINT_MASK_LAYER_NAME = 'inpaint_mask_layer';
@ -51,9 +49,8 @@ export const getLayerId = (entityId: string) => `${RASTER_LAYER_NAME}_${entityId
export const getBrushLineId = (entityId: string, lineId: string) => `${entityId}.brush_line_${lineId}`; export const getBrushLineId = (entityId: string, lineId: string) => `${entityId}.brush_line_${lineId}`;
export const getEraserLineId = (entityId: string, lineId: string) => `${entityId}.eraser_line_${lineId}`; export const getEraserLineId = (entityId: string, lineId: string) => `${entityId}.eraser_line_${lineId}`;
export const getRectShapeId = (entityId: string, rectId: string) => `${entityId}.rect_${rectId}`; export const getRectShapeId = (entityId: string, rectId: string) => `${entityId}.rect_${rectId}`;
export const getImageObjectId = (entityId: string, imageName: string) => `${entityId}.image_${imageName}`; export const getImageObjectId = (entityId: string, imageId: string) => `${entityId}.image_${imageId}`;
export const getObjectGroupId = (entityId: string, groupId: string) => `${entityId}.objectGroup_${groupId}`; export const getObjectGroupId = (entityId: string, groupId: string) => `${entityId}.objectGroup_${groupId}`;
export const getLayerBboxId = (entityId: string) => `${entityId}.bbox`; export const getLayerBboxId = (entityId: string) => `${entityId}.bbox`;
export const getCAId = (entityId: string) => `control_adapter_${entityId}`; export const getCAId = (entityId: string) => `${CA_LAYER_NAME}_${entityId}`;
export const getCAImageId = (entityId: string, imageName: string) => `${entityId}.image_${imageName}`;
export const getIPAId = (entityId: string) => `ip_adapter_${entityId}`; export const getIPAId = (entityId: string) => `ip_adapter_${entityId}`;

@ -1,23 +1,27 @@
import type { EntityToKonvaMap } from 'features/controlLayers/konva/entityToKonvaMap';
import { BACKGROUND_LAYER_ID, PREVIEW_LAYER_ID } from 'features/controlLayers/konva/naming'; import { BACKGROUND_LAYER_ID, PREVIEW_LAYER_ID } from 'features/controlLayers/konva/naming';
import type { ControlAdapterEntity, LayerEntity, RegionEntity } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity, LayerEntity, RegionEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva'; import type Konva from 'konva';
export const arrangeEntities = ( export const arrangeEntities = (
stage: Konva.Stage, stage: Konva.Stage,
layerMap: EntityToKonvaMap,
layers: LayerEntity[], layers: LayerEntity[],
controlAdapterMap: EntityToKonvaMap,
controlAdapters: ControlAdapterEntity[], controlAdapters: ControlAdapterEntity[],
regionMap: EntityToKonvaMap,
regions: RegionEntity[] regions: RegionEntity[]
): void => { ): void => {
let zIndex = 0; let zIndex = 0;
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(++zIndex); stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(++zIndex);
for (const layer of layers) { for (const layer of layers) {
stage.findOne<Konva.Layer>(`#${layer.id}`)?.zIndex(++zIndex); layerMap.getMapping(layer.id)?.konvaLayer.zIndex(++zIndex);
} }
for (const ca of controlAdapters) { for (const ca of controlAdapters) {
stage.findOne<Konva.Layer>(`#${ca.id}`)?.zIndex(++zIndex); controlAdapterMap.getMapping(ca.id)?.konvaLayer.zIndex(++zIndex);
} }
for (const rg of regions) { for (const rg of regions) {
stage.findOne<Konva.Layer>(`#${rg.id}`)?.zIndex(++zIndex); regionMap.getMapping(rg.id)?.konvaLayer.zIndex(++zIndex);
} }
stage.findOne<Konva.Layer>(`#${PREVIEW_LAYER_ID}`)?.zIndex(++zIndex); stage.findOne<Konva.Layer>(`#${PREVIEW_LAYER_ID}`)?.zIndex(++zIndex);
}; };

@ -1,9 +1,15 @@
import type { EntityToKonvaMap } from 'features/controlLayers/konva/entityToKonvaMap'; import type { EntityToKonvaMap, EntityToKonvaMapping, ImageEntry } from 'features/controlLayers/konva/entityToKonvaMap';
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, getCAImageId } from 'features/controlLayers/konva/naming'; import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, CA_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/konva/naming';
import {
createImageObjectGroup,
createObjectGroup,
updateImageSource,
} from 'features/controlLayers/konva/renderers/objects';
import type { ControlAdapterEntity } from 'features/controlLayers/store/types'; import type { ControlAdapterEntity } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { ImageDTO } from 'services/api/types'; import { isEqual } from 'lodash-es';
import { assert } from 'tsafe';
/** /**
* Logic for creating and rendering control adapter (control net & t2i adapter) layers. These layers have image objects * Logic for creating and rendering control adapter (control net & t2i adapter) layers. These layers have image objects
@ -13,164 +19,98 @@ import type { ImageDTO } from 'services/api/types';
/** /**
* Creates a control adapter layer. * Creates a control adapter layer.
* @param stage The konva stage * @param stage The konva stage
* @param ca The control adapter layer state * @param entity The control adapter layer state
*/ */
const createCALayer = (stage: Konva.Stage, ca: ControlAdapterEntity): Konva.Layer => { const getControlAdapter = (map: EntityToKonvaMap, entity: ControlAdapterEntity): EntityToKonvaMapping => {
let mapping = map.getMapping(entity.id);
if (mapping) {
return mapping;
}
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: ca.id, id: entity.id,
name: CA_LAYER_NAME, name: CA_LAYER_NAME,
imageSmoothingEnabled: false, imageSmoothingEnabled: false,
listening: false, listening: false,
}); });
stage.add(konvaLayer); const konvaObjectGroup = createObjectGroup(konvaLayer, CA_LAYER_OBJECT_GROUP_NAME);
return konvaLayer; map.stage.add(konvaLayer);
}; mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
return mapping;
/**
* Creates a control adapter layer image.
* @param konvaLayer The konva layer
* @param imageEl The image element
*/
const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CA_LAYER_IMAGE_NAME,
image: imageEl,
listening: false,
});
konvaLayer.add(konvaImage);
return konvaImage;
};
/**
* Updates the image source for a control adapter layer. This includes loading the image from the server and updating
* the konva image.
* @param stage The konva stage
* @param konvaLayer The konva layer
* @param ca The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const updateControlAdapterImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): Promise<void> => {
const image = ca.processedImage ?? ca.image;
if (image) {
const imageName = image.name;
const imageDTO = await getImageDTO(imageName);
if (!imageDTO) {
return;
}
const imageEl = new Image();
const imageId = getCAImageId(ca.id, imageName);
imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image: imageEl,
});
updateControlAdapterImageAttrs(stage, konvaImage, ca);
// Must cache after this to apply the filters
konvaImage.cache();
imageEl.id = imageId;
};
imageEl.src = imageDTO.image_url;
} else {
konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`)?.destroy();
}
};
/**
* Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters).
* @param stage The konva stage
* @param konvaImage The konva image
* @param ca The control adapter layer state
*/
const updateControlAdapterImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, ca: ControlAdapterEntity): void => {
let needsCache = false;
// TODO(psyche): `node.filters()` returns null if no filters; report upstream
const filters = konvaImage.filters() ?? [];
const filter = filters[0] ?? null;
const filterNeedsUpdate = (filter === null && ca.filter !== 'none') || (filter && filter.name !== ca.filter);
if (
konvaImage.x() !== ca.x ||
konvaImage.y() !== ca.y ||
konvaImage.visible() !== ca.isEnabled ||
filterNeedsUpdate
) {
konvaImage.setAttrs({
opacity: ca.opacity,
scaleX: 1,
scaleY: 1,
visible: ca.isEnabled,
filters: ca.filter === 'LightnessToAlphaFilter' ? [LightnessToAlphaFilter] : [],
});
needsCache = true;
}
if (konvaImage.opacity() !== ca.opacity) {
konvaImage.opacity(ca.opacity);
}
if (needsCache) {
konvaImage.cache();
}
}; };
/** /**
* Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated * Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated
* with the current image source and attributes. * with the current image source and attributes.
* @param stage The konva stage * @param stage The konva stage
* @param ca The control adapter layer state * @param entity The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source * @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/ */
export const renderControlAdapter = ( export const renderControlAdapter = async (map: EntityToKonvaMap, entity: ControlAdapterEntity): Promise<void> => {
stage: Konva.Stage, const mapping = getControlAdapter(map, entity);
controlAdapterMap: EntityToKonvaMap, const imageObject = entity.processedImageObject ?? entity.imageObject;
ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${ca.id}`) ?? createCALayer(stage, ca);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
let imageSourceNeedsUpdate = false; if (!imageObject) {
// The user has deleted/reset the image
if (canvasImageSource instanceof HTMLImageElement) { mapping.getEntries().forEach((entry) => {
const image = ca.processedImage ?? ca.image; mapping.destroyEntry(entry.id);
if (image && canvasImageSource.id !== getCAImageId(ca.id, image.name)) { });
imageSourceNeedsUpdate = true; return;
} else if (!image) {
imageSourceNeedsUpdate = true;
}
} else if (!canvasImageSource) {
imageSourceNeedsUpdate = true;
} }
if (imageSourceNeedsUpdate) { let entry = mapping.getEntries<ImageEntry>()[0];
updateControlAdapterImageSource(stage, konvaLayer, ca, getImageDTO); const opacity = entity.opacity;
} else if (konvaImage) { const visible = entity.isEnabled;
updateControlAdapterImageAttrs(stage, konvaImage, ca); const filters = entity.filter === 'LightnessToAlphaFilter' ? [LightnessToAlphaFilter] : [];
if (!entry) {
entry = await createImageObjectGroup({
mapping,
obj: imageObject,
name: CA_LAYER_IMAGE_NAME,
onLoad: (konvaImage) => {
konvaImage.filters(filters);
konvaImage.cache();
konvaImage.opacity(opacity);
konvaImage.visible(visible);
},
});
} else {
if (entry.isLoading || entry.isError) {
return;
}
assert(entry.konvaImage, `Image entry ${entry.id} must have a konva image if it is not loading or in error state`);
const imageSource = entry.konvaImage.image();
assert(imageSource instanceof HTMLImageElement, `Image source must be an HTMLImageElement`);
if (imageSource.id !== imageObject.image.name) {
updateImageSource({
entry,
image: imageObject.image,
onLoad: (konvaImage) => {
konvaImage.filters(filters);
konvaImage.cache();
konvaImage.opacity(opacity);
konvaImage.visible(visible);
},
});
} else {
if (!isEqual(entry.konvaImage.filters(), filters)) {
entry.konvaImage.filters(filters);
entry.konvaImage.cache();
}
entry.konvaImage.opacity(opacity);
entry.konvaImage.visible(visible);
}
} }
}; };
export const renderControlAdapters = ( export const renderControlAdapters = (map: EntityToKonvaMap, entities: ControlAdapterEntity[]): void => {
stage: Konva.Stage,
controlAdapterMap: EntityToKonvaMap,
controlAdapters: ControlAdapterEntity[],
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
// Destroy nonexistent layers // Destroy nonexistent layers
for (const mapping of controlAdapterMap.getMappings()) { for (const mapping of map.getMappings()) {
if (!controlAdapters.find((ca) => ca.id === mapping.id)) { if (!entities.find((ca) => ca.id === mapping.id)) {
controlAdapterMap.destroyMapping(mapping.id); map.destroyMapping(mapping.id);
} }
} }
for (const ca of controlAdapters) { for (const ca of entities) {
renderControlAdapter(stage, controlAdapterMap, ca, getImageDTO); renderControlAdapter(map, ca);
} }
}; };

@ -25,22 +25,21 @@ import Konva from 'konva';
/** /**
* Creates a raster layer. * Creates a raster layer.
* @param stage The konva stage * @param stage The konva stage
* @param layerState The raster layer state * @param entity The raster layer state
* @param onPosChanged Callback for when the layer's position changes * @param onPosChanged Callback for when the layer's position changes
*/ */
const getLayer = ( const getLayer = (
stage: Konva.Stage, map: EntityToKonvaMap,
layerMap: EntityToKonvaMap, entity: LayerEntity,
layerState: LayerEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): EntityToKonvaMapping => { ): EntityToKonvaMapping => {
let mapping = layerMap.getMapping(layerState.id); let mapping = map.getMapping(entity.id);
if (mapping) { if (mapping) {
return mapping; return mapping;
} }
// This layer hasn't been added to the konva state yet // This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: layerState.id, id: entity.id,
name: RASTER_LAYER_NAME, name: RASTER_LAYER_NAME,
draggable: true, draggable: true,
dragDistance: 0, dragDistance: 0,
@ -50,41 +49,39 @@ const getLayer = (
// the position - we do not need to call this on the `dragmove` event. // the position - we do not need to call this on the `dragmove` event.
if (onPosChanged) { if (onPosChanged) {
konvaLayer.on('dragend', function (e) { konvaLayer.on('dragend', function (e) {
onPosChanged({ id: layerState.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'layer'); onPosChanged({ id: entity.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'layer');
}); });
} }
const konvaObjectGroup = createObjectGroup(konvaLayer, RASTER_LAYER_OBJECT_GROUP_NAME); const konvaObjectGroup = createObjectGroup(konvaLayer, RASTER_LAYER_OBJECT_GROUP_NAME);
konvaLayer.add(konvaObjectGroup); map.stage.add(konvaLayer);
stage.add(konvaLayer); mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
mapping = layerMap.addMapping(layerState.id, konvaLayer, konvaObjectGroup);
return mapping; return mapping;
}; };
/** /**
* Renders a regional guidance layer. * Renders a regional guidance layer.
* @param stage The konva stage * @param stage The konva stage
* @param layerState The regional guidance layer state * @param entity The regional guidance layer state
* @param tool The current tool * @param tool The current tool
* @param onPosChanged Callback for when the layer's position changes * @param onPosChanged Callback for when the layer's position changes
*/ */
export const renderLayer = async ( export const renderLayer = async (
stage: Konva.Stage, map: EntityToKonvaMap,
layerMap: EntityToKonvaMap, entity: LayerEntity,
layerState: LayerEntity,
tool: Tool, tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
) => { ) => {
const mapping = getLayer(stage, layerMap, layerState, onPosChanged); const mapping = getLayer(map, entity, onPosChanged);
// Update the layer's position and listening state // Update the layer's position and listening state
mapping.konvaLayer.setAttrs({ mapping.konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
x: Math.floor(layerState.x), x: Math.floor(entity.x),
y: Math.floor(layerState.y), y: Math.floor(entity.y),
}); });
const objectIds = layerState.objects.map(mapId); const objectIds = entity.objects.map(mapId);
// Destroy any objects that are no longer in state // Destroy any objects that are no longer in state
for (const entry of mapping.getEntries()) { for (const entry of mapping.getEntries()) {
if (!objectIds.includes(entry.id)) { if (!objectIds.includes(entry.id)) {
@ -92,7 +89,7 @@ export const renderLayer = async (
} }
} }
for (const obj of layerState.objects) { for (const obj of entity.objects) {
if (obj.type === 'brush_line') { if (obj.type === 'brush_line') {
const entry = getBrushLine(mapping, obj, RASTER_LAYER_BRUSH_LINE_NAME); const entry = getBrushLine(mapping, obj, RASTER_LAYER_BRUSH_LINE_NAME);
// Only update the points if they have changed. // Only update the points if they have changed.
@ -108,13 +105,13 @@ export const renderLayer = async (
} else if (obj.type === 'rect_shape') { } else if (obj.type === 'rect_shape') {
getRectShape(mapping, obj, RASTER_LAYER_RECT_SHAPE_NAME); getRectShape(mapping, obj, RASTER_LAYER_RECT_SHAPE_NAME);
} else if (obj.type === 'image') { } else if (obj.type === 'image') {
createImageObjectGroup(mapping, obj, RASTER_LAYER_IMAGE_NAME); createImageObjectGroup({ mapping, obj, name: RASTER_LAYER_IMAGE_NAME });
} }
} }
// Only update layer visibility if it has changed. // Only update layer visibility if it has changed.
if (mapping.konvaLayer.visible() !== layerState.isEnabled) { if (mapping.konvaLayer.visible() !== entity.isEnabled) {
mapping.konvaLayer.visible(layerState.isEnabled); mapping.konvaLayer.visible(entity.isEnabled);
} }
// const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layerState, konvaLayer); // const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layerState, konvaLayer);
@ -135,23 +132,22 @@ export const renderLayer = async (
// bboxRect.visible(false); // bboxRect.visible(false);
// } // }
mapping.konvaObjectGroup.opacity(layerState.opacity); mapping.konvaObjectGroup.opacity(entity.opacity);
}; };
export const renderLayers = ( export const renderLayers = (
stage: Konva.Stage, map: EntityToKonvaMap,
layerMap: EntityToKonvaMap, entities: LayerEntity[],
layers: LayerEntity[],
tool: Tool, tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => { ): void => {
// Destroy nonexistent layers // Destroy nonexistent layers
for (const mapping of layerMap.getMappings()) { for (const mapping of map.getMappings()) {
if (!layers.find((l) => l.id === mapping.id)) { if (!entities.find((l) => l.id === mapping.id)) {
layerMap.destroyMapping(mapping.id); map.destroyMapping(mapping.id);
} }
} }
for (const layer of layers) { for (const layer of entities) {
renderLayer(stage, layerMap, layer, tool, onPosChanged); renderLayer(map, layer, tool, onPosChanged);
} }
}; };

@ -9,14 +9,23 @@ import type {
import { import {
getLayerBboxId, getLayerBboxId,
getObjectGroupId, getObjectGroupId,
IMAGE_PLACEHOLDER_NAME,
LAYER_BBOX_NAME, LAYER_BBOX_NAME,
PREVIEW_GENERATION_BBOX_DUMMY_RECT, PREVIEW_GENERATION_BBOX_DUMMY_RECT,
} from 'features/controlLayers/konva/naming'; } from 'features/controlLayers/konva/naming';
import type { BrushLine, CanvasEntity, EraserLine, ImageObject, RectShape } from 'features/controlLayers/store/types'; import type {
BrushLine,
CanvasEntity,
EraserLine,
ImageObject,
ImageWithDims,
RectShape,
} from 'features/controlLayers/store/types';
import { DEFAULT_RGBA_COLOR } from 'features/controlLayers/store/types'; import { DEFAULT_RGBA_COLOR } from 'features/controlLayers/store/types';
import { t } from 'i18next'; import { t } from 'i18next';
import Konva from 'konva'; import Konva from 'konva';
import { getImageDTO } from 'services/api/endpoints/images'; import { getImageDTO as defaultGetImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
/** /**
@ -120,16 +129,93 @@ export const getRectShape = (mapping: EntityToKonvaMapping, rectShape: RectShape
return entry; return entry;
}; };
export const updateImageSource = async (arg: {
entry: ImageEntry;
image: ImageWithDims;
getImageDTO?: (imageName: string) => Promise<ImageDTO | null>;
onLoading?: () => void;
onLoad?: (konvaImage: Konva.Image) => void;
onError?: () => void;
}) => {
const { entry, image, getImageDTO = defaultGetImageDTO, onLoading, onLoad, onError } = arg;
try {
entry.isLoading = true;
if (!entry.konvaImage) {
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image'));
}
onLoading?.();
const imageDTO = await getImageDTO(image.name);
if (!imageDTO) {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
return;
}
const imageEl = new Image();
imageEl.onload = () => {
if (entry.konvaImage) {
entry.konvaImage.setAttrs({
image: imageEl,
});
} else {
entry.konvaImage = new Konva.Image({
id: entry.id,
listening: false,
image: imageEl,
});
entry.konvaImageGroup.add(entry.konvaImage);
}
entry.isLoading = false;
entry.isError = false;
entry.konvaPlaceholderGroup.visible(false);
onLoad?.(entry.konvaImage);
};
imageEl.onerror = () => {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
};
imageEl.id = image.name;
imageEl.src = imageDTO.image_url;
} catch {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
}
};
/** /**
* Creates an image placeholder group for an image object. * Creates an image placeholder group for an image object.
* @param imageObject The image object state * @param image The image object state
* @returns The konva group for the image placeholder, and callbacks to handle loading and error states * @returns The konva group for the image placeholder, and callbacks to handle loading and error states
*/ */
const createImagePlaceholderGroup = ( export const createImageObjectGroup = (arg: {
imageObject: ImageObject mapping: EntityToKonvaMapping;
): { konvaPlaceholderGroup: Konva.Group; onError: () => void; onLoading: () => void; onLoaded: () => void } => { obj: ImageObject;
const { width, height } = imageObject.image; name: string;
const konvaPlaceholderGroup = new Konva.Group({ name: 'image-placeholder', listening: false }); getImageDTO?: (imageName: string) => Promise<ImageDTO | null>;
onLoad?: (konvaImage: Konva.Image) => void;
onLoading?: () => void;
onError?: () => void;
}): ImageEntry => {
const { mapping, obj, name, getImageDTO = defaultGetImageDTO, onLoad, onLoading, onError } = arg;
let entry = mapping.getEntry<ImageEntry>(obj.id);
if (entry) {
return entry;
}
const { id, image } = obj;
const { width, height } = obj;
const konvaImageGroup = new Konva.Group({ id, name, listening: false });
const konvaPlaceholderGroup = new Konva.Group({ name: IMAGE_PLACEHOLDER_NAME, listening: false });
const konvaPlaceholderRect = new Konva.Rect({ const konvaPlaceholderRect = new Konva.Rect({
fill: 'hsl(220 12% 45% / 1)', // 'base.500' fill: 'hsl(220 12% 45% / 1)', // 'base.500'
width, width,
@ -137,7 +223,6 @@ const createImagePlaceholderGroup = (
listening: false, listening: false,
}); });
const konvaPlaceholderText = new Konva.Text({ const konvaPlaceholderText = new Konva.Text({
name: 'image-placeholder-text',
fill: 'hsl(220 12% 10% / 1)', // 'base.900' fill: 'hsl(220 12% 10% / 1)', // 'base.900'
width, width,
height, height,
@ -146,70 +231,25 @@ const createImagePlaceholderGroup = (
fontFamily: '"Inter Variable", sans-serif', fontFamily: '"Inter Variable", sans-serif',
fontSize: width / 16, fontSize: width / 16,
fontStyle: '600', fontStyle: '600',
text: 'Loading Image', text: t('common.loadingImage', 'Loading Image'),
listening: false, listening: false,
}); });
konvaPlaceholderGroup.add(konvaPlaceholderRect); konvaPlaceholderGroup.add(konvaPlaceholderRect);
konvaPlaceholderGroup.add(konvaPlaceholderText); konvaPlaceholderGroup.add(konvaPlaceholderText);
konvaImageGroup.add(konvaPlaceholderGroup);
const onError = () => {
konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
};
const onLoading = () => {
konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image'));
};
const onLoaded = () => {
konvaPlaceholderGroup.destroy();
};
return { konvaPlaceholderGroup, onError, onLoading, onLoaded };
};
/**
* Creates an image object group. Because images are loaded asynchronously, and we need to handle loading an error state,
* the image is rendered in a group, which includes a placeholder.
* @param imageObject The image object state
* @param layerObjectGroup The konva layer's object group to add the image to
* @param name The konva name for the image
* @returns A promise that resolves to the konva group for the image object
*/
export const createImageObjectGroup = async (
mapping: EntityToKonvaMapping,
imageObject: ImageObject,
name: string
): Promise<ImageEntry> => {
let entry = mapping.getEntry<ImageEntry>(imageObject.id);
if (entry) {
return entry;
}
const konvaImageGroup = new Konva.Group({ id: imageObject.id, name, listening: false });
const placeholder = createImagePlaceholderGroup(imageObject);
konvaImageGroup.add(placeholder.konvaPlaceholderGroup);
mapping.konvaObjectGroup.add(konvaImageGroup); mapping.konvaObjectGroup.add(konvaImageGroup);
entry = mapping.addEntry({ id: imageObject.id, type: 'image', konvaGroup: konvaImageGroup, konvaImage: null }); entry = mapping.addEntry({
getImageDTO(imageObject.image.name).then((imageDTO) => { id,
if (!imageDTO) { type: 'image',
placeholder.onError(); konvaImageGroup,
return; konvaPlaceholderGroup,
} konvaPlaceholderText,
const imageEl = new Image(); konvaImage: null,
imageEl.onload = () => { isLoading: false,
const konvaImage = new Konva.Image({ isError: false,
id: imageObject.id,
name,
listening: false,
image: imageEl,
});
placeholder.onLoaded();
konvaImageGroup.add(konvaImage);
entry.konvaImage = konvaImage;
};
imageEl.onerror = () => {
placeholder.onError();
};
imageEl.id = imageObject.id;
imageEl.src = imageDTO.image_url;
}); });
updateImageSource({ entry, image, getImageDTO, onLoad, onLoading, onError });
return entry; return entry;
}; };

@ -45,22 +45,21 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
/** /**
* Creates a regional guidance layer. * Creates a regional guidance layer.
* @param stage The konva stage * @param stage The konva stage
* @param region The regional guidance layer state * @param entity The regional guidance layer state
* @param onLayerPosChanged Callback for when the layer's position changes * @param onLayerPosChanged Callback for when the layer's position changes
*/ */
const getRegion = ( const getRegion = (
stage: Konva.Stage, map: EntityToKonvaMap,
regionMap: EntityToKonvaMap, entity: RegionEntity,
region: RegionEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): EntityToKonvaMapping => { ): EntityToKonvaMapping => {
let mapping = regionMap.getMapping(region.id); let mapping = map.getMapping(entity.id);
if (mapping) { if (mapping) {
return mapping; return mapping;
} }
// This layer hasn't been added to the konva state yet // This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({ const konvaLayer = new Konva.Layer({
id: region.id, id: entity.id,
name: RG_LAYER_NAME, name: RG_LAYER_NAME,
draggable: true, draggable: true,
dragDistance: 0, dragDistance: 0,
@ -70,51 +69,48 @@ const getRegion = (
// the position - we do not need to call this on the `dragmove` event. // the position - we do not need to call this on the `dragmove` event.
if (onPosChanged) { if (onPosChanged) {
konvaLayer.on('dragend', function (e) { konvaLayer.on('dragend', function (e) {
onPosChanged({ id: region.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'regional_guidance'); onPosChanged({ id: entity.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'regional_guidance');
}); });
} }
const konvaObjectGroup = createObjectGroup(konvaLayer, RG_LAYER_OBJECT_GROUP_NAME); const konvaObjectGroup = createObjectGroup(konvaLayer, RG_LAYER_OBJECT_GROUP_NAME);
map.stage.add(konvaLayer);
konvaLayer.add(konvaObjectGroup); mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
stage.add(konvaLayer);
mapping = regionMap.addMapping(region.id, konvaLayer, konvaObjectGroup);
return mapping; return mapping;
}; };
/** /**
* Renders a raster layer. * Renders a raster layer.
* @param stage The konva stage * @param stage The konva stage
* @param region The regional guidance layer state * @param entity The regional guidance layer state
* @param globalMaskLayerOpacity The global mask layer opacity * @param globalMaskLayerOpacity The global mask layer opacity
* @param tool The current tool * @param tool The current tool
* @param onPosChanged Callback for when the layer's position changes * @param onPosChanged Callback for when the layer's position changes
*/ */
export const renderRegion = ( export const renderRegion = (
stage: Konva.Stage, map: EntityToKonvaMap,
regionMap: EntityToKonvaMap, entity: RegionEntity,
region: RegionEntity,
globalMaskLayerOpacity: number, globalMaskLayerOpacity: number,
tool: Tool, tool: Tool,
selectedEntityIdentifier: CanvasEntityIdentifier | null, selectedEntityIdentifier: CanvasEntityIdentifier | null,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => { ): void => {
const mapping = getRegion(stage, regionMap, region, onPosChanged); const mapping = getRegion(map, entity, onPosChanged);
// Update the layer's position and listening state // Update the layer's position and listening state
mapping.konvaLayer.setAttrs({ mapping.konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
x: Math.floor(region.x), x: Math.floor(entity.x),
y: Math.floor(region.y), y: Math.floor(entity.y),
}); });
// Convert the color to a string, stripping the alpha - the object group will handle opacity. // Convert the color to a string, stripping the alpha - the object group will handle opacity.
const rgbColor = rgbColorToString(region.fill); const rgbColor = rgbColorToString(entity.fill);
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required. // We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
let groupNeedsCache = false; let groupNeedsCache = false;
const objectIds = region.objects.map(mapId); const objectIds = entity.objects.map(mapId);
// Destroy any objects that are no longer in state // Destroy any objects that are no longer in state
for (const entry of mapping.getEntries()) { for (const entry of mapping.getEntries()) {
if (!objectIds.includes(entry.id)) { if (!objectIds.includes(entry.id)) {
@ -123,7 +119,7 @@ export const renderRegion = (
} }
} }
for (const obj of region.objects) { for (const obj of entity.objects) {
if (obj.type === 'brush_line') { if (obj.type === 'brush_line') {
const entry = getBrushLine(mapping, obj, RG_LAYER_BRUSH_LINE_NAME); const entry = getBrushLine(mapping, obj, RG_LAYER_BRUSH_LINE_NAME);
@ -164,8 +160,8 @@ export const renderRegion = (
} }
// Only update layer visibility if it has changed. // Only update layer visibility if it has changed.
if (mapping.konvaLayer.visible() !== region.isEnabled) { if (mapping.konvaLayer.visible() !== entity.isEnabled) {
mapping.konvaLayer.visible(region.isEnabled); mapping.konvaLayer.visible(entity.isEnabled);
groupNeedsCache = true; groupNeedsCache = true;
} }
@ -177,7 +173,7 @@ export const renderRegion = (
const compositingRect = const compositingRect =
mapping.konvaLayer.findOne<Konva.Rect>(`.${COMPOSITING_RECT_NAME}`) ?? createCompositingRect(mapping.konvaLayer); mapping.konvaLayer.findOne<Konva.Rect>(`.${COMPOSITING_RECT_NAME}`) ?? createCompositingRect(mapping.konvaLayer);
const isSelected = selectedEntityIdentifier?.id === region.id; const isSelected = selectedEntityIdentifier?.id === entity.id;
/** /**
* When the group is selected, we use a rect of the selected preview color, composited over the shapes. This allows * When the group is selected, we use a rect of the selected preview color, composited over the shapes. This allows
@ -200,7 +196,7 @@ export const renderRegion = (
compositingRect.setAttrs({ compositingRect.setAttrs({
// The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already // The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
...(!region.bboxNeedsUpdate && region.bbox ? region.bbox : getLayerBboxFast(mapping.konvaLayer)), ...(!entity.bboxNeedsUpdate && entity.bbox ? entity.bbox : getLayerBboxFast(mapping.konvaLayer)),
fill: rgbColor, fill: rgbColor,
opacity: globalMaskLayerOpacity, opacity: globalMaskLayerOpacity,
// Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes) // Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
@ -240,21 +236,20 @@ export const renderRegion = (
}; };
export const renderRegions = ( export const renderRegions = (
stage: Konva.Stage, map: EntityToKonvaMap,
regionMap: EntityToKonvaMap, entities: RegionEntity[],
regions: RegionEntity[],
maskOpacity: number, maskOpacity: number,
tool: Tool, tool: Tool,
selectedEntityIdentifier: CanvasEntityIdentifier | null, selectedEntityIdentifier: CanvasEntityIdentifier | null,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => { ): void => {
// Destroy nonexistent layers // Destroy nonexistent layers
for (const mapping of regionMap.getMappings()) { for (const mapping of map.getMappings()) {
if (!regions.find((rg) => rg.id === mapping.id)) { if (!entities.find((rg) => rg.id === mapping.id)) {
regionMap.destroyMapping(mapping.id); map.destroyMapping(mapping.id);
} }
} }
for (const rg of regions) { for (const rg of entities) {
renderRegion(stage, regionMap, rg, maskOpacity, tool, selectedEntityIdentifier, onPosChanged); renderRegion(map, rg, maskOpacity, tool, selectedEntityIdentifier, onPosChanged);
} }
}; };

@ -54,7 +54,6 @@ import type Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types'; import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es'; import { debounce } from 'lodash-es';
import type { RgbaColor } from 'react-colorful'; import type { RgbaColor } from 'react-colorful';
import { getImageDTO } from 'services/api/endpoints/images';
/** /**
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the * Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
@ -283,9 +282,9 @@ export const initializeRenderer = (
// the entire state over when needed. // the entire state over when needed.
const debouncedUpdateBboxes = debounce(updateBboxes, 300); const debouncedUpdateBboxes = debounce(updateBboxes, 300);
const regionMap = new EntityToKonvaMap(); const regionMap = new EntityToKonvaMap(stage);
const layerMap = new EntityToKonvaMap(); const layerMap = new EntityToKonvaMap(stage);
const controlAdapterMap = new EntityToKonvaMap(); const controlAdapterMap = new EntityToKonvaMap(stage);
const renderCanvas = () => { const renderCanvas = () => {
const { canvasV2 } = store.getState(); const { canvasV2 } = store.getState();
@ -304,7 +303,7 @@ export const initializeRenderer = (
canvasV2.tool.selected !== prevCanvasV2.tool.selected canvasV2.tool.selected !== prevCanvasV2.tool.selected
) { ) {
logIfDebugging('Rendering layers'); logIfDebugging('Rendering layers');
renderLayers(stage, layerMap, canvasV2.layers, canvasV2.tool.selected, onPosChanged); renderLayers(layerMap, canvasV2.layers, canvasV2.tool.selected, onPosChanged);
} }
if ( if (
@ -315,7 +314,6 @@ export const initializeRenderer = (
) { ) {
logIfDebugging('Rendering regions'); logIfDebugging('Rendering regions');
renderRegions( renderRegions(
stage,
regionMap, regionMap,
canvasV2.regions, canvasV2.regions,
canvasV2.settings.maskOpacity, canvasV2.settings.maskOpacity,
@ -327,7 +325,7 @@ export const initializeRenderer = (
if (isFirstRender || canvasV2.controlAdapters !== prevCanvasV2.controlAdapters) { if (isFirstRender || canvasV2.controlAdapters !== prevCanvasV2.controlAdapters) {
logIfDebugging('Rendering control adapters'); logIfDebugging('Rendering control adapters');
renderControlAdapters(stage, controlAdapterMap, canvasV2.controlAdapters, getImageDTO); renderControlAdapters(controlAdapterMap, canvasV2.controlAdapters);
} }
if (isFirstRender || canvasV2.document !== prevCanvasV2.document) { if (isFirstRender || canvasV2.document !== prevCanvasV2.document) {
@ -367,7 +365,15 @@ export const initializeRenderer = (
canvasV2.regions !== prevCanvasV2.regions canvasV2.regions !== prevCanvasV2.regions
) { ) {
logIfDebugging('Arranging entities'); logIfDebugging('Arranging entities');
arrangeEntities(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions); arrangeEntities(
stage,
layerMap,
canvasV2.layers,
controlAdapterMap,
canvasV2.controlAdapters,
regionMap,
canvasV2.regions
);
} }
prevCanvasV2 = canvasV2; prevCanvasV2 = canvasV2;

@ -1,6 +1,5 @@
import { import {
CA_LAYER_NAME, CA_LAYER_NAME,
INITIAL_IMAGE_LAYER_NAME,
INPAINT_MASK_LAYER_NAME, INPAINT_MASK_LAYER_NAME,
RASTER_LAYER_BRUSH_LINE_NAME, RASTER_LAYER_BRUSH_LINE_NAME,
RASTER_LAYER_ERASER_LINE_NAME, RASTER_LAYER_ERASER_LINE_NAME,
@ -88,7 +87,6 @@ export const mapId = (object: { id: string }): string => object.id;
export const selectRenderableLayers = (node: Konva.Node): boolean => export const selectRenderableLayers = (node: Konva.Node): boolean =>
node.name() === RG_LAYER_NAME || node.name() === RG_LAYER_NAME ||
node.name() === CA_LAYER_NAME || node.name() === CA_LAYER_NAME ||
node.name() === INITIAL_IMAGE_LAYER_NAME ||
node.name() === RASTER_LAYER_NAME || node.name() === RASTER_LAYER_NAME ||
node.name() === INPAINT_MASK_LAYER_NAME; node.name() === INPAINT_MASK_LAYER_NAME;

@ -28,6 +28,21 @@ const initialState: CanvasV2State = {
ipAdapters: [], ipAdapters: [],
regions: [], regions: [],
loras: [], loras: [],
inpaintMask: {
bbox: null,
bboxNeedsUpdate: false,
fill: {
type: 'color_fill',
color: DEFAULT_RGBA_COLOR,
},
id: 'inpaint_mask',
imageCache: null,
isEnabled: false,
maskObjects: [],
type: 'inpaint_mask',
x: 0,
y: 0,
},
tool: { tool: {
selected: 'bbox', selected: 'bbox',
selectedBuffer: null, selectedBuffer: null,

@ -18,7 +18,7 @@ import type {
T2IAdapterConfig, T2IAdapterConfig,
T2IAdapterData, T2IAdapterData,
} from './types'; } from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types'; import { buildControlAdapterProcessorV2, imageDTOToImageObject } from './types';
export const selectCA = (state: CanvasV2State, id: string) => state.controlAdapters.find((ca) => ca.id === id); export const selectCA = (state: CanvasV2State, id: string) => state.controlAdapters.find((ca) => ca.id === id);
export const selectCAOrThrow = (state: CanvasV2State, id: string) => { export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
@ -128,37 +128,43 @@ export const controlAdaptersReducers = {
} }
moveToStart(state.controlAdapters, ca); moveToStart(state.controlAdapters, ca);
}, },
caImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => { caImageChanged: {
const { id, imageDTO } = action.payload; reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const ca = selectCA(state, id); const { id, imageDTO, objectId } = action.payload;
if (!ca) { const ca = selectCA(state, id);
return; if (!ca) {
}
ca.bbox = null;
ca.bboxNeedsUpdate = true;
ca.isEnabled = true;
if (imageDTO) {
const newImage = imageDTOToImageWithDims(imageDTO);
if (isEqual(newImage, ca.image)) {
return; return;
} }
ca.image = newImage; ca.bbox = null;
ca.processedImage = null; ca.bboxNeedsUpdate = true;
} else { ca.isEnabled = true;
ca.image = null; if (imageDTO) {
ca.processedImage = null; const newImageObject = imageDTOToImageObject(id, objectId, imageDTO);
} if (isEqual(newImageObject, ca.imageObject)) {
return;
}
ca.imageObject = newImageObject;
ca.processedImageObject = null;
} else {
ca.imageObject = null;
ca.processedImageObject = null;
}
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },
caProcessedImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => { caProcessedImageChanged: {
const { id, imageDTO } = action.payload; reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const ca = selectCA(state, id); const { id, imageDTO, objectId } = action.payload;
if (!ca) { const ca = selectCA(state, id);
return; if (!ca) {
} return;
ca.bbox = null; }
ca.bboxNeedsUpdate = true; ca.bbox = null;
ca.isEnabled = true; ca.bboxNeedsUpdate = true;
ca.processedImage = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; ca.isEnabled = true;
ca.processedImageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },
caModelChanged: ( caModelChanged: (
state, state,
@ -182,7 +188,7 @@ export const controlAdaptersReducers = {
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) { if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth // The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
// model. We need to use the new processor. // model. We need to use the new processor.
ca.processedImage = null; ca.processedImageObject = null;
ca.processorConfig = candidateProcessorConfig; ca.processorConfig = candidateProcessorConfig;
} }
@ -212,7 +218,7 @@ export const controlAdaptersReducers = {
} }
ca.processorConfig = processorConfig; ca.processorConfig = processorConfig;
if (!processorConfig) { if (!processorConfig) {
ca.processedImage = null; ca.processedImageObject = null;
} }
}, },
caFilterChanged: (state, action: PayloadAction<{ id: string; filter: Filter }>) => { caFilterChanged: (state, action: PayloadAction<{ id: string; filter: Filter }>) => {

@ -4,8 +4,14 @@ import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import type { CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPAdapterEntity, IPMethodV2 } from './types'; import type {
import { imageDTOToImageWithDims } from './types'; CanvasV2State,
CLIPVisionModelV2,
IPAdapterConfig,
IPAdapterEntity,
IPMethodV2,
} from './types';
import { imageDTOToImageObject } from './types';
export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.find((ipa) => ipa.id === id); export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.find((ipa) => ipa.id === id);
export const selectIPAOrThrow = (state: CanvasV2State, id: string) => { export const selectIPAOrThrow = (state: CanvasV2State, id: string) => {
@ -48,13 +54,16 @@ export const ipAdaptersReducers = {
ipaAllDeleted: (state) => { ipaAllDeleted: (state) => {
state.ipAdapters = []; state.ipAdapters = [];
}, },
ipaImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => { ipaImageChanged: {
const { id, imageDTO } = action.payload; reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const ipa = selectIPA(state, id); const { id, imageDTO, objectId } = action.payload;
if (!ipa) { const ipa = selectIPA(state, id);
return; if (!ipa) {
} return;
ipa.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; }
ipa.imageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },
ipaMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethodV2 }>) => { ipaMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethodV2 }>) => {
const { id, method } = action.payload; const { id, method } = action.payload;

@ -1,6 +1,6 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getImageObjectId, getRectShapeId } from 'features/controlLayers/konva/naming'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
@ -14,7 +14,7 @@ import type {
PointAddedToLineArg, PointAddedToLineArg,
RectShapeAddedArg, RectShapeAddedArg,
} from './types'; } from './types';
import { isLine } from './types'; import { imageDTOToImageObject, isLine } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.find((layer) => layer.id === id); export const selectLayer = (state: CanvasV2State, id: string) => state.layers.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => { export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
@ -73,7 +73,9 @@ export const layersReducers = {
layer.bbox = bbox; layer.bbox = bbox;
layer.bboxNeedsUpdate = false; layer.bboxNeedsUpdate = false;
if (bbox === null) { if (bbox === null) {
layer.objects = []; // TODO(psyche): Clear objects when bbox is cleared - right now this doesn't work bc bbox calculation for layers
// doesn't work - always returns null
// layer.objects = [];
} }
}, },
layerReset: (state, action: PayloadAction<{ id: string }>) => { layerReset: (state, action: PayloadAction<{ id: string }>) => {
@ -212,24 +214,15 @@ export const layersReducers = {
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }), prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
}, },
layerImageAdded: { layerImageAdded: {
reducer: (state, action: PayloadAction<ImageObjectAddedArg & { imageId: string }>) => { reducer: (state, action: PayloadAction<ImageObjectAddedArg & { objectId: string }>) => {
const { id, imageId, imageDTO } = action.payload; const { id, objectId, imageDTO } = action.payload;
const layer = selectLayer(state, id); const layer = selectLayer(state, id);
if (!layer) { if (!layer) {
return; return;
} }
const { width, height, image_name: name } = imageDTO; layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.objects.push({
type: 'image',
id: getImageObjectId(id, imageId),
x: 0,
y: 0,
width,
height,
image: { width, height, name },
});
layer.bboxNeedsUpdate = true; layer.bboxNeedsUpdate = true;
}, },
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, imageId: uuidv4() } }), prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
}, },
} satisfies SliceCaseReducers<CanvasV2State>; } satisfies SliceCaseReducers<CanvasV2State>;

@ -1,8 +1,12 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils'; import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming'; import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types'; import type {
import { imageDTOToImageWithDims } from 'features/controlLayers/store/types'; CanvasV2State,
CLIPVisionModelV2,
IPMethodV2,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect } from 'konva/lib/types'; import type { IRect } from 'konva/lib/types';
@ -210,20 +214,25 @@ export const regionsReducers = {
} }
rg.ipAdapters = rg.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId); rg.ipAdapters = rg.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
}, },
rgIPAdapterImageChanged: ( rgIPAdapterImageChanged: {
state, reducer: (
action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null }> state,
) => { action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null; objectId: string }>
const { id, ipAdapterId, imageDTO } = action.payload; ) => {
const rg = selectRG(state, id); const { id, ipAdapterId, imageDTO, objectId } = action.payload;
if (!rg) { const rg = selectRG(state, id);
return; if (!rg) {
} return;
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId); }
if (!ipa) { const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
return; if (!ipa) {
} return;
ipa.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; }
ipa.imageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; ipAdapterId: string; imageDTO: ImageDTO | null }) => ({
payload: { ...payload, objectId: uuidv4() },
}),
}, },
rgIPAdapterWeightChanged: (state, action: PayloadAction<{ id: string; ipAdapterId: string; weight: number }>) => { rgIPAdapterWeightChanged: (state, action: PayloadAction<{ id: string; ipAdapterId: string; weight: number }>) => {
const { id, ipAdapterId, weight } = action.payload; const { id, ipAdapterId, weight } = action.payload;

@ -1,3 +1,4 @@
import { getImageObjectId } from 'features/controlLayers/konva/naming';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types'; import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type { import type {
@ -536,6 +537,9 @@ const zRectShape = z.object({
}); });
export type RectShape = z.infer<typeof zRectShape>; export type RectShape = z.infer<typeof zRectShape>;
const zFilter = z.enum(['LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>;
const zImageObject = z.object({ const zImageObject = z.object({
id: zId, id: zId,
type: z.literal('image'), type: z.literal('image'),
@ -544,6 +548,7 @@ const zImageObject = z.object({
y: z.number(), y: z.number(),
width: z.number().min(1), width: z.number().min(1),
height: z.number().min(1), height: z.number().min(1),
filters: z.array(zFilter),
}); });
export type ImageObject = z.infer<typeof zImageObject>; export type ImageObject = z.infer<typeof zImageObject>;
@ -569,7 +574,7 @@ export const zIPAdapterEntity = z.object({
isEnabled: z.boolean(), isEnabled: z.boolean(),
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
method: zIPMethodV2, method: zIPMethodV2,
image: zImageWithDims.nullable(), imageObject: zImageObject.nullable(),
model: zModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
clipVisionModel: zCLIPVisionModelV2, clipVisionModel: zCLIPVisionModelV2,
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
@ -577,7 +582,7 @@ export const zIPAdapterEntity = z.object({
export type IPAdapterEntity = z.infer<typeof zIPAdapterEntity>; export type IPAdapterEntity = z.infer<typeof zIPAdapterEntity>;
export type IPAdapterConfig = Pick< export type IPAdapterConfig = Pick<
IPAdapterEntity, IPAdapterEntity,
'weight' | 'image' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method' 'weight' | 'imageObject' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method'
>; >;
const zMaskObject = z const zMaskObject = z
@ -642,7 +647,7 @@ const zImageFill = z.object({
src: z.string(), src: z.string(),
}); });
const zFill = z.discriminatedUnion('type', [zColorFill, zImageFill]); const zFill = z.discriminatedUnion('type', [zColorFill, zImageFill]);
const zInpaintMaskData = z.object({ const zInpaintMaskEntity = z.object({
id: zId, id: zId,
type: z.literal('inpaint_mask'), type: z.literal('inpaint_mask'),
isEnabled: z.boolean(), isEnabled: z.boolean(),
@ -654,10 +659,7 @@ const zInpaintMaskData = z.object({
fill: zFill, fill: zFill,
imageCache: zImageWithDims.nullable(), imageCache: zImageWithDims.nullable(),
}); });
export type InpaintMaskData = z.infer<typeof zInpaintMaskData>; export type InpaintMaskEntity = z.infer<typeof zInpaintMaskEntity>;
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>;
const zControlAdapterEntityBase = z.object({ const zControlAdapterEntityBase = z.object({
id: zId, id: zId,
@ -670,8 +672,8 @@ const zControlAdapterEntityBase = z.object({
opacity: zOpacity, opacity: zOpacity,
filter: zFilter, filter: zFilter,
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
image: zImageWithDims.nullable(), imageObject: zImageObject.nullable(),
processedImage: zImageWithDims.nullable(), processedImageObject: zImageObject.nullable(),
processorConfig: zProcessorConfig.nullable(), processorConfig: zProcessorConfig.nullable(),
processorPendingBatchId: z.string().nullable().default(null), processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
@ -693,8 +695,8 @@ export type ControlNetConfig = Pick<
ControlNetData, ControlNetData,
| 'adapterType' | 'adapterType'
| 'weight' | 'weight'
| 'image' | 'imageObject'
| 'processedImage' | 'processedImageObject'
| 'processorConfig' | 'processorConfig'
| 'beginEndStepPct' | 'beginEndStepPct'
| 'model' | 'model'
@ -702,7 +704,7 @@ export type ControlNetConfig = Pick<
>; >;
export type T2IAdapterConfig = Pick< export type T2IAdapterConfig = Pick<
T2IAdapterData, T2IAdapterData,
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model' 'adapterType' | 'weight' | 'imageObject' | 'processedImageObject' | 'processorConfig' | 'beginEndStepPct' | 'model'
>; >;
export const initialControlNetV2: ControlNetConfig = { export const initialControlNetV2: ControlNetConfig = {
@ -711,8 +713,8 @@ export const initialControlNetV2: ControlNetConfig = {
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
controlMode: 'balanced', controlMode: 'balanced',
image: null, imageObject: null,
processedImage: null, processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
}; };
@ -721,13 +723,13 @@ export const initialT2IAdapterV2: T2IAdapterConfig = {
model: null, model: null,
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
image: null, imageObject: null,
processedImage: null, processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(), processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
}; };
export const initialIPAdapterV2: IPAdapterConfig = { export const initialIPAdapterV2: IPAdapterConfig = {
image: null, imageObject: null,
model: null, model: null,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
method: 'full', method: 'full',
@ -752,12 +754,30 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height, height,
}); });
export const imageDTOToImageObject = (entityId: string, objectId: string, imageDTO: ImageDTO): ImageObject => {
const { width, height, image_name } = imageDTO;
return {
id: getImageObjectId(entityId, objectId),
type: 'image',
x: 0,
y: 0,
width,
height,
filters: [],
image: {
name: image_name,
width,
height,
},
};
};
const zBoundingBoxScaleMethod = z.enum(['none', 'auto', 'manual']); const zBoundingBoxScaleMethod = z.enum(['none', 'auto', 'manual']);
export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>; export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod => export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
zBoundingBoxScaleMethod.safeParse(v).success; zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntity = LayerEntity | IPAdapterEntity | ControlAdapterEntity | RegionEntity | InpaintMaskData; export type CanvasEntity = LayerEntity | ControlAdapterEntity | RegionEntity | InpaintMaskEntity | IPAdapterEntity;
export type CanvasEntityIdentifier = Pick<CanvasEntity, 'id' | 'type'>; export type CanvasEntityIdentifier = Pick<CanvasEntity, 'id' | 'type'>;
export type Dimensions = { export type Dimensions = {
@ -775,6 +795,7 @@ export type LoRA = {
export type CanvasV2State = { export type CanvasV2State = {
_version: 3; _version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null; selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: InpaintMaskEntity;
layers: LayerEntity[]; layers: LayerEntity[];
controlAdapters: ControlAdapterEntity[]; controlAdapters: ControlAdapterEntity[];
ipAdapters: IPAdapterEntity[]; ipAdapters: IPAdapterEntity[];
@ -871,3 +892,14 @@ export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO };
export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => { export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => {
return obj.type === 'brush_line' || obj.type === 'eraser_line'; return obj.type === 'brush_line' || obj.type === 'eraser_line';
}; };
/**
* A helper type to remove `[index: string]: any;` from a type.
* This is useful for some Konva types that include `[index: string]: any;` in addition to statically named
* properties, effectively widening the type signature to `Record<string, any>`. For example, `LineConfig`,
* `RectConfig`, `ImageConfig`, etc all include `[index: string]: any;` in their type signature.
* TODO(psyche): Fix this upstream.
*/
export type RemoveIndexString<T> = {
[K in keyof T as string extends K ? never : K]: T[K];
};

@ -25,7 +25,7 @@ export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_
(ca) => ca.image?.name === image_name || ca.processedImage?.name === image_name (ca) => ca.image?.name === image_name || ca.processedImage?.name === image_name
); );
const isIPAdapterImage = canvasV2.ipAdapters.some((ipa) => ipa.image?.name === image_name); const isIPAdapterImage = canvasV2.ipAdapters.some((ipa) => ipa.imageObject?.name === image_name);
const imageUsage: ImageUsage = { const imageUsage: ImageUsage = {
isLayerImage, isLayerImage,

@ -692,7 +692,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<IPAdapterEntity> = async
model: zModelIdentifierField.parse(ipAdapterModel), model: zModelIdentifierField.parse(ipAdapterModel),
weight: typeof weight === 'number' ? weight : initialIPAdapterV2.weight, weight: typeof weight === 'number' ? weight : initialIPAdapterV2.weight,
beginEndStepPct, beginEndStepPct,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null, imageObject: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
clipVisionModel: initialIPAdapterV2.clipVisionModel, // TODO: This needs to be added to the zIPAdapterField... clipVisionModel: initialIPAdapterV2.clipVisionModel, // TODO: This needs to be added to the zIPAdapterField...
method: method ?? initialIPAdapterV2.method, method: method ?? initialIPAdapterV2.method,
}; };

@ -278,10 +278,10 @@ const recallCA: MetadataRecallFunc<ControlAdapterEntity> = async (ca) => {
const recallIPA: MetadataRecallFunc<IPAdapterEntity> = async (ipa) => { const recallIPA: MetadataRecallFunc<IPAdapterEntity> = async (ipa) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
const clone = deepClone(ipa); const clone = deepClone(ipa);
if (clone.image) { if (clone.imageObject) {
const imageDTO = await getImageDTO(clone.image.name); const imageDTO = await getImageDTO(clone.imageObject.name);
if (!imageDTO) { if (!imageDTO) {
clone.image = null; clone.imageObject = null;
} }
} }
if (clone.model) { if (clone.model) {
@ -305,10 +305,10 @@ const recallRG: MetadataRecallFunc<RegionEntity> = async (rg) => {
clone.imageCache = null; clone.imageCache = null;
for (const ipAdapter of clone.ipAdapters) { for (const ipAdapter of clone.ipAdapters) {
if (ipAdapter.image) { if (ipAdapter.imageObject) {
const imageDTO = await getImageDTO(ipAdapter.image.name); const imageDTO = await getImageDTO(ipAdapter.imageObject.name);
if (!imageDTO) { if (!imageDTO) {
ipAdapter.image = null; ipAdapter.imageObject = null;
} }
} }
if (ipAdapter.model) { if (ipAdapter.model) {

@ -34,7 +34,7 @@ export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise
}; };
const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa;
assert(image, 'IP Adapter image is required'); assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -59,6 +59,6 @@ export const isValidIPAdapter = (ipa: IPAdapterEntity, base: BaseModelType): boo
// Must be have a model that matches the current base and must have a control image // Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model); const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base; const modelMatchesBase = ipa.model?.base === base;
const hasImage = Boolean(ipa.image); const hasImage = Boolean(ipa.imageObject);
return hasModel && modelMatchesBase && hasImage; return hasModel && modelMatchesBase && hasImage;
}; };

@ -190,7 +190,7 @@ export const addRegions = async (
for (const ipa of validRGIPAdapters) { for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required'); assert(image, 'IP Adapter image is required');