refactor(ui): create classes to abstract mgmt of konva nodes

This commit is contained in:
psychedelicious 2024-06-19 20:30:49 +10:00
parent 995c26751e
commit d965df8ca9
28 changed files with 482 additions and 433 deletions

@ -67,11 +67,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// We should only process if the processor settings or image have changed
const originalCA = selectCA(originalState.canvasV2, id);
const originalImage = originalCA?.image;
const originalImage = originalCA?.imageObject;
const originalConfig = originalCA?.processorConfig;
const image = ca.image;
const processedImage = ca.processedImage;
const image = ca.imageObject;
const processedImage = ca.processedImageObject;
const config = ca.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
@ -95,7 +95,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image.image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {

@ -49,7 +49,7 @@ const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, ima
};
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.canvasV2.ipAdapters.forEach(({ id, image }) => {
state.canvasV2.ipAdapters.forEach(({ id, imageObject: image }) => {
if (image?.name === imageDTO.image_name) {
dispatch(ipaImageChanged({ id, imageDTO: null }));
}

@ -178,7 +178,7 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipa.image) {
if (!ipa.imageObject) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
@ -214,7 +214,7 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
if (!ipAdapter.imageObject) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});

@ -47,10 +47,10 @@ export const CAImagePreview = memo(
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
controlAdapter.image?.name ?? skipToken
controlAdapter.imageObject?.image.name ?? skipToken
);
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
controlAdapter.processedImage?.name ?? skipToken
controlAdapter.processedImageObject?.image.name ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();

@ -95,7 +95,7 @@ export const IPASettings = memo(({ id }: Props) => {
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAImagePreview
image={ipAdapter.image}
image={ipAdapter.imageObject?.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
droppableData={droppableData}

@ -123,7 +123,7 @@ export const RGIPAdapterSettings = memo(({ id, ipAdapterId, ipAdapterNumber }: P
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAImagePreview
image={ipAdapter.image}
image={ipAdapter.imageObject?.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
droppableData={droppableData}

@ -24,21 +24,27 @@ export type RectShapeEntry = {
export type ImageEntry = {
id: string;
type: ImageObject['type'];
konvaImageGroup: Konva.Group;
konvaPlaceholderGroup: Konva.Group;
konvaPlaceholderText: Konva.Text;
konvaImage: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
konvaGroup: Konva.Group;
isLoading: boolean;
isError: boolean;
};
type Entry = BrushLineEntry | EraserLineEntry | RectShapeEntry | ImageEntry;
export class EntityToKonvaMap {
stage: Konva.Stage;
mappings: Record<string, EntityToKonvaMapping>;
constructor() {
constructor(stage: Konva.Stage) {
this.stage = stage;
this.mappings = {};
}
addMapping(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group): EntityToKonvaMapping {
const mapping = new EntityToKonvaMapping(id, konvaLayer, konvaObjectGroup);
const mapping = new EntityToKonvaMapping(id, konvaLayer, konvaObjectGroup, this);
this.mappings[id] = mapping;
return mapping;
}
@ -66,12 +72,14 @@ export class EntityToKonvaMapping {
konvaLayer: Konva.Layer;
konvaObjectGroup: Konva.Group;
konvaNodeEntries: Record<string, Entry>;
map: EntityToKonvaMap;
constructor(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group) {
constructor(id: string, konvaLayer: Konva.Layer, konvaObjectGroup: Konva.Group, map: EntityToKonvaMap) {
this.id = id;
this.konvaLayer = konvaLayer;
this.konvaObjectGroup = konvaObjectGroup;
this.konvaNodeEntries = {};
this.map = map;
}
addEntry<T extends Entry>(entry: T): T {
@ -83,8 +91,8 @@ export class EntityToKonvaMapping {
return this.konvaNodeEntries[id] as T | undefined;
}
getEntries(): Entry[] {
return Object.values(this.konvaNodeEntries);
getEntries<T extends Entry>(): T[] {
return Object.values(this.konvaNodeEntries) as T[];
}
destroyEntry(id: string): void {
@ -97,7 +105,7 @@ export class EntityToKonvaMapping {
} else if (entry.type === 'rect_shape') {
entry.konvaRect.destroy();
} else if (entry.type === 'image') {
entry.konvaGroup.destroy();
entry.konvaImageGroup.destroy();
}
delete this.konvaNodeEntries[id];
}

@ -4,42 +4,40 @@
// IDs for singleton Konva layers and objects
export const PREVIEW_LAYER_ID = 'preview_layer';
export const PREVIEW_TOOL_GROUP_ID = 'preview_layer.tool_group';
export const PREVIEW_BRUSH_GROUP_ID = 'preview_layer.brush_group';
export const PREVIEW_BRUSH_FILL_ID = 'preview_layer.brush_fill';
export const PREVIEW_BRUSH_BORDER_INNER_ID = 'preview_layer.brush_border_inner';
export const PREVIEW_BRUSH_BORDER_OUTER_ID = 'preview_layer.brush_border_outer';
export const PREVIEW_RECT_ID = 'preview_layer.rect';
export const PREVIEW_GENERATION_BBOX_GROUP = 'preview_layer.gen_bbox_group';
export const PREVIEW_GENERATION_BBOX_TRANSFORMER = 'preview_layer.gen_bbox_transformer';
export const PREVIEW_GENERATION_BBOX_DUMMY_RECT = 'preview_layer.gen_bbox_dummy_rect';
export const PREVIEW_DOCUMENT_SIZE_GROUP = 'preview_layer.doc_size_group';
export const PREVIEW_DOCUMENT_SIZE_STAGE_RECT = 'preview_layer.doc_size_stage_rect';
export const PREVIEW_DOCUMENT_SIZE_DOCUMENT_RECT = 'preview_layer.doc_size_doc_rect';
export const PREVIEW_TOOL_GROUP_ID = `${PREVIEW_LAYER_ID}.tool_group`;
export const PREVIEW_BRUSH_GROUP_ID = `${PREVIEW_LAYER_ID}.brush_group`;
export const PREVIEW_BRUSH_FILL_ID = `${PREVIEW_LAYER_ID}.brush_fill`;
export const PREVIEW_BRUSH_BORDER_INNER_ID = `${PREVIEW_LAYER_ID}.brush_border_inner`;
export const PREVIEW_BRUSH_BORDER_OUTER_ID = `${PREVIEW_LAYER_ID}.brush_border_outer`;
export const PREVIEW_RECT_ID = `${PREVIEW_LAYER_ID}.rect`;
export const PREVIEW_GENERATION_BBOX_GROUP = `${PREVIEW_LAYER_ID}.gen_bbox_group`;
export const PREVIEW_GENERATION_BBOX_TRANSFORMER = `${PREVIEW_LAYER_ID}.gen_bbox_transformer`;
export const PREVIEW_GENERATION_BBOX_DUMMY_RECT = `${PREVIEW_LAYER_ID}.gen_bbox_dummy_rect`;
export const PREVIEW_DOCUMENT_SIZE_GROUP = `${PREVIEW_LAYER_ID}.doc_size_group`;
export const PREVIEW_DOCUMENT_SIZE_STAGE_RECT = `${PREVIEW_LAYER_ID}.doc_size_stage_rect`;
export const PREVIEW_DOCUMENT_SIZE_DOCUMENT_RECT = `${PREVIEW_LAYER_ID}.doc_size_doc_rect`;
// Names for Konva layers and objects (comparable to CSS classes)
export const LAYER_BBOX_NAME = 'layer.bbox';
export const COMPOSITING_RECT_NAME = 'compositing-rect';
export const LAYER_BBOX_NAME = 'layer_bbox';
export const COMPOSITING_RECT_NAME = 'compositing_rect';
export const IMAGE_PLACEHOLDER_NAME = 'image_placeholder';
export const CA_LAYER_NAME = 'control_adapter_layer';
export const CA_LAYER_IMAGE_NAME = 'control_adapter_layer.image';
export const INITIAL_IMAGE_LAYER_ID = 'singleton_initial_image_layer';
export const INITIAL_IMAGE_LAYER_NAME = 'initial_image_layer';
export const INITIAL_IMAGE_LAYER_IMAGE_NAME = 'initial_image_layer.image';
export const CA_LAYER_NAME = 'control_adapter';
export const CA_LAYER_OBJECT_GROUP_NAME = `${CA_LAYER_NAME}.object_group`;
export const CA_LAYER_IMAGE_NAME = `${CA_LAYER_NAME}.image`;
export const RG_LAYER_NAME = 'regional_guidance_layer';
export const RG_LAYER_OBJECT_GROUP_NAME = 'regional_guidance_layer.object_group';
export const RG_LAYER_BRUSH_LINE_NAME = 'regional_guidance_layer.brush_line';
export const RG_LAYER_ERASER_LINE_NAME = 'regional_guidance_layer.eraser_line';
export const RG_LAYER_RECT_SHAPE_NAME = 'regional_guidance_layer.rect_shape';
export const RG_LAYER_OBJECT_GROUP_NAME = `${RG_LAYER_NAME}.object_group`;
export const RG_LAYER_BRUSH_LINE_NAME = `${RG_LAYER_NAME}.brush_line`;
export const RG_LAYER_ERASER_LINE_NAME = `${RG_LAYER_NAME}.eraser_line`;
export const RG_LAYER_RECT_SHAPE_NAME = `${RG_LAYER_NAME}.rect_shape`;
export const RASTER_LAYER_NAME = 'raster_layer';
export const RASTER_LAYER_OBJECT_GROUP_NAME = 'raster_layer.object_group';
export const RASTER_LAYER_BRUSH_LINE_NAME = 'raster_layer.brush_line';
export const RASTER_LAYER_ERASER_LINE_NAME = 'raster_layer.eraser_line';
export const RASTER_LAYER_RECT_SHAPE_NAME = 'raster_layer.rect_shape';
export const RASTER_LAYER_IMAGE_NAME = 'raster_layer.image';
export const RASTER_LAYER_OBJECT_GROUP_NAME = `${RASTER_LAYER_NAME}.object_group`;
export const RASTER_LAYER_BRUSH_LINE_NAME = `${RASTER_LAYER_NAME}.brush_line`;
export const RASTER_LAYER_ERASER_LINE_NAME = `${RASTER_LAYER_NAME}.eraser_line`;
export const RASTER_LAYER_RECT_SHAPE_NAME = `${RASTER_LAYER_NAME}.rect_shape`;
export const RASTER_LAYER_IMAGE_NAME = `${RASTER_LAYER_NAME}.image`;
export const INPAINT_MASK_LAYER_NAME = 'inpaint_mask_layer';
@ -51,9 +49,8 @@ export const getLayerId = (entityId: string) => `${RASTER_LAYER_NAME}_${entityId
export const getBrushLineId = (entityId: string, lineId: string) => `${entityId}.brush_line_${lineId}`;
export const getEraserLineId = (entityId: string, lineId: string) => `${entityId}.eraser_line_${lineId}`;
export const getRectShapeId = (entityId: string, rectId: string) => `${entityId}.rect_${rectId}`;
export const getImageObjectId = (entityId: string, imageName: string) => `${entityId}.image_${imageName}`;
export const getImageObjectId = (entityId: string, imageId: string) => `${entityId}.image_${imageId}`;
export const getObjectGroupId = (entityId: string, groupId: string) => `${entityId}.objectGroup_${groupId}`;
export const getLayerBboxId = (entityId: string) => `${entityId}.bbox`;
export const getCAId = (entityId: string) => `control_adapter_${entityId}`;
export const getCAImageId = (entityId: string, imageName: string) => `${entityId}.image_${imageName}`;
export const getCAId = (entityId: string) => `${CA_LAYER_NAME}_${entityId}`;
export const getIPAId = (entityId: string) => `ip_adapter_${entityId}`;

@ -1,23 +1,27 @@
import type { EntityToKonvaMap } from 'features/controlLayers/konva/entityToKonvaMap';
import { BACKGROUND_LAYER_ID, PREVIEW_LAYER_ID } from 'features/controlLayers/konva/naming';
import type { ControlAdapterEntity, LayerEntity, RegionEntity } from 'features/controlLayers/store/types';
import type Konva from 'konva';
export const arrangeEntities = (
stage: Konva.Stage,
layerMap: EntityToKonvaMap,
layers: LayerEntity[],
controlAdapterMap: EntityToKonvaMap,
controlAdapters: ControlAdapterEntity[],
regionMap: EntityToKonvaMap,
regions: RegionEntity[]
): void => {
let zIndex = 0;
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(++zIndex);
for (const layer of layers) {
stage.findOne<Konva.Layer>(`#${layer.id}`)?.zIndex(++zIndex);
layerMap.getMapping(layer.id)?.konvaLayer.zIndex(++zIndex);
}
for (const ca of controlAdapters) {
stage.findOne<Konva.Layer>(`#${ca.id}`)?.zIndex(++zIndex);
controlAdapterMap.getMapping(ca.id)?.konvaLayer.zIndex(++zIndex);
}
for (const rg of regions) {
stage.findOne<Konva.Layer>(`#${rg.id}`)?.zIndex(++zIndex);
regionMap.getMapping(rg.id)?.konvaLayer.zIndex(++zIndex);
}
stage.findOne<Konva.Layer>(`#${PREVIEW_LAYER_ID}`)?.zIndex(++zIndex);
};

@ -1,9 +1,15 @@
import type { EntityToKonvaMap } from 'features/controlLayers/konva/entityToKonvaMap';
import type { EntityToKonvaMap, EntityToKonvaMapping, ImageEntry } from 'features/controlLayers/konva/entityToKonvaMap';
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, getCAImageId } from 'features/controlLayers/konva/naming';
import { CA_LAYER_IMAGE_NAME, CA_LAYER_NAME, CA_LAYER_OBJECT_GROUP_NAME } from 'features/controlLayers/konva/naming';
import {
createImageObjectGroup,
createObjectGroup,
updateImageSource,
} from 'features/controlLayers/konva/renderers/objects';
import type { ControlAdapterEntity } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { ImageDTO } from 'services/api/types';
import { isEqual } from 'lodash-es';
import { assert } from 'tsafe';
/**
* Logic for creating and rendering control adapter (control net & t2i adapter) layers. These layers have image objects
@ -13,164 +19,98 @@ import type { ImageDTO } from 'services/api/types';
/**
* Creates a control adapter layer.
* @param stage The konva stage
* @param ca The control adapter layer state
* @param entity The control adapter layer state
*/
const createCALayer = (stage: Konva.Stage, ca: ControlAdapterEntity): Konva.Layer => {
const getControlAdapter = (map: EntityToKonvaMap, entity: ControlAdapterEntity): EntityToKonvaMapping => {
let mapping = map.getMapping(entity.id);
if (mapping) {
return mapping;
}
const konvaLayer = new Konva.Layer({
id: ca.id,
id: entity.id,
name: CA_LAYER_NAME,
imageSmoothingEnabled: false,
listening: false,
});
stage.add(konvaLayer);
return konvaLayer;
};
/**
* Creates a control adapter layer image.
* @param konvaLayer The konva layer
* @param imageEl The image element
*/
const createCALayerImage = (konvaLayer: Konva.Layer, imageEl: HTMLImageElement): Konva.Image => {
const konvaImage = new Konva.Image({
name: CA_LAYER_IMAGE_NAME,
image: imageEl,
listening: false,
});
konvaLayer.add(konvaImage);
return konvaImage;
};
/**
* Updates the image source for a control adapter layer. This includes loading the image from the server and updating
* the konva image.
* @param stage The konva stage
* @param konvaLayer The konva layer
* @param ca The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
const updateControlAdapterImageSource = async (
stage: Konva.Stage,
konvaLayer: Konva.Layer,
ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): Promise<void> => {
const image = ca.processedImage ?? ca.image;
if (image) {
const imageName = image.name;
const imageDTO = await getImageDTO(imageName);
if (!imageDTO) {
return;
}
const imageEl = new Image();
const imageId = getCAImageId(ca.id, imageName);
imageEl.onload = () => {
// Find the existing image or create a new one - must find using the name, bc the id may have just changed
const konvaImage =
konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`) ?? createCALayerImage(konvaLayer, imageEl);
// Update the image's attributes
konvaImage.setAttrs({
id: imageId,
image: imageEl,
});
updateControlAdapterImageAttrs(stage, konvaImage, ca);
// Must cache after this to apply the filters
konvaImage.cache();
imageEl.id = imageId;
};
imageEl.src = imageDTO.image_url;
} else {
konvaLayer.findOne(`.${CA_LAYER_IMAGE_NAME}`)?.destroy();
}
};
/**
* Updates the image attributes for a control adapter layer's image (width, height, visibility, opacity, filters).
* @param stage The konva stage
* @param konvaImage The konva image
* @param ca The control adapter layer state
*/
const updateControlAdapterImageAttrs = (stage: Konva.Stage, konvaImage: Konva.Image, ca: ControlAdapterEntity): void => {
let needsCache = false;
// TODO(psyche): `node.filters()` returns null if no filters; report upstream
const filters = konvaImage.filters() ?? [];
const filter = filters[0] ?? null;
const filterNeedsUpdate = (filter === null && ca.filter !== 'none') || (filter && filter.name !== ca.filter);
if (
konvaImage.x() !== ca.x ||
konvaImage.y() !== ca.y ||
konvaImage.visible() !== ca.isEnabled ||
filterNeedsUpdate
) {
konvaImage.setAttrs({
opacity: ca.opacity,
scaleX: 1,
scaleY: 1,
visible: ca.isEnabled,
filters: ca.filter === 'LightnessToAlphaFilter' ? [LightnessToAlphaFilter] : [],
});
needsCache = true;
}
if (konvaImage.opacity() !== ca.opacity) {
konvaImage.opacity(ca.opacity);
}
if (needsCache) {
konvaImage.cache();
}
const konvaObjectGroup = createObjectGroup(konvaLayer, CA_LAYER_OBJECT_GROUP_NAME);
map.stage.add(konvaLayer);
mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
return mapping;
};
/**
* Renders a control adapter layer. If the layer doesn't already exist, it is created. Otherwise, the layer is updated
* with the current image source and attributes.
* @param stage The konva stage
* @param ca The control adapter layer state
* @param entity The control adapter layer state
* @param getImageDTO A function to retrieve an image DTO from the server, used to update the image source
*/
export const renderControlAdapter = (
stage: Konva.Stage,
controlAdapterMap: EntityToKonvaMap,
ca: ControlAdapterEntity,
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
const konvaLayer = stage.findOne<Konva.Layer>(`#${ca.id}`) ?? createCALayer(stage, ca);
const konvaImage = konvaLayer.findOne<Konva.Image>(`.${CA_LAYER_IMAGE_NAME}`);
const canvasImageSource = konvaImage?.image();
export const renderControlAdapter = async (map: EntityToKonvaMap, entity: ControlAdapterEntity): Promise<void> => {
const mapping = getControlAdapter(map, entity);
const imageObject = entity.processedImageObject ?? entity.imageObject;
let imageSourceNeedsUpdate = false;
if (canvasImageSource instanceof HTMLImageElement) {
const image = ca.processedImage ?? ca.image;
if (image && canvasImageSource.id !== getCAImageId(ca.id, image.name)) {
imageSourceNeedsUpdate = true;
} else if (!image) {
imageSourceNeedsUpdate = true;
}
} else if (!canvasImageSource) {
imageSourceNeedsUpdate = true;
if (!imageObject) {
// The user has deleted/reset the image
mapping.getEntries().forEach((entry) => {
mapping.destroyEntry(entry.id);
});
return;
}
if (imageSourceNeedsUpdate) {
updateControlAdapterImageSource(stage, konvaLayer, ca, getImageDTO);
} else if (konvaImage) {
updateControlAdapterImageAttrs(stage, konvaImage, ca);
let entry = mapping.getEntries<ImageEntry>()[0];
const opacity = entity.opacity;
const visible = entity.isEnabled;
const filters = entity.filter === 'LightnessToAlphaFilter' ? [LightnessToAlphaFilter] : [];
if (!entry) {
entry = await createImageObjectGroup({
mapping,
obj: imageObject,
name: CA_LAYER_IMAGE_NAME,
onLoad: (konvaImage) => {
konvaImage.filters(filters);
konvaImage.cache();
konvaImage.opacity(opacity);
konvaImage.visible(visible);
},
});
} else {
if (entry.isLoading || entry.isError) {
return;
}
assert(entry.konvaImage, `Image entry ${entry.id} must have a konva image if it is not loading or in error state`);
const imageSource = entry.konvaImage.image();
assert(imageSource instanceof HTMLImageElement, `Image source must be an HTMLImageElement`);
if (imageSource.id !== imageObject.image.name) {
updateImageSource({
entry,
image: imageObject.image,
onLoad: (konvaImage) => {
konvaImage.filters(filters);
konvaImage.cache();
konvaImage.opacity(opacity);
konvaImage.visible(visible);
},
});
} else {
if (!isEqual(entry.konvaImage.filters(), filters)) {
entry.konvaImage.filters(filters);
entry.konvaImage.cache();
}
entry.konvaImage.opacity(opacity);
entry.konvaImage.visible(visible);
}
}
};
export const renderControlAdapters = (
stage: Konva.Stage,
controlAdapterMap: EntityToKonvaMap,
controlAdapters: ControlAdapterEntity[],
getImageDTO: (imageName: string) => Promise<ImageDTO | null>
): void => {
export const renderControlAdapters = (map: EntityToKonvaMap, entities: ControlAdapterEntity[]): void => {
// Destroy nonexistent layers
for (const mapping of controlAdapterMap.getMappings()) {
if (!controlAdapters.find((ca) => ca.id === mapping.id)) {
controlAdapterMap.destroyMapping(mapping.id);
for (const mapping of map.getMappings()) {
if (!entities.find((ca) => ca.id === mapping.id)) {
map.destroyMapping(mapping.id);
}
}
for (const ca of controlAdapters) {
renderControlAdapter(stage, controlAdapterMap, ca, getImageDTO);
for (const ca of entities) {
renderControlAdapter(map, ca);
}
};

@ -25,22 +25,21 @@ import Konva from 'konva';
/**
* Creates a raster layer.
* @param stage The konva stage
* @param layerState The raster layer state
* @param entity The raster layer state
* @param onPosChanged Callback for when the layer's position changes
*/
const getLayer = (
stage: Konva.Stage,
layerMap: EntityToKonvaMap,
layerState: LayerEntity,
map: EntityToKonvaMap,
entity: LayerEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): EntityToKonvaMapping => {
let mapping = layerMap.getMapping(layerState.id);
let mapping = map.getMapping(entity.id);
if (mapping) {
return mapping;
}
// This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({
id: layerState.id,
id: entity.id,
name: RASTER_LAYER_NAME,
draggable: true,
dragDistance: 0,
@ -50,41 +49,39 @@ const getLayer = (
// the position - we do not need to call this on the `dragmove` event.
if (onPosChanged) {
konvaLayer.on('dragend', function (e) {
onPosChanged({ id: layerState.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'layer');
onPosChanged({ id: entity.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'layer');
});
}
const konvaObjectGroup = createObjectGroup(konvaLayer, RASTER_LAYER_OBJECT_GROUP_NAME);
konvaLayer.add(konvaObjectGroup);
stage.add(konvaLayer);
mapping = layerMap.addMapping(layerState.id, konvaLayer, konvaObjectGroup);
map.stage.add(konvaLayer);
mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
return mapping;
};
/**
* Renders a regional guidance layer.
* @param stage The konva stage
* @param layerState The regional guidance layer state
* @param entity The regional guidance layer state
* @param tool The current tool
* @param onPosChanged Callback for when the layer's position changes
*/
export const renderLayer = async (
stage: Konva.Stage,
layerMap: EntityToKonvaMap,
layerState: LayerEntity,
map: EntityToKonvaMap,
entity: LayerEntity,
tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
) => {
const mapping = getLayer(stage, layerMap, layerState, onPosChanged);
const mapping = getLayer(map, entity, onPosChanged);
// Update the layer's position and listening state
mapping.konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
x: Math.floor(layerState.x),
y: Math.floor(layerState.y),
x: Math.floor(entity.x),
y: Math.floor(entity.y),
});
const objectIds = layerState.objects.map(mapId);
const objectIds = entity.objects.map(mapId);
// Destroy any objects that are no longer in state
for (const entry of mapping.getEntries()) {
if (!objectIds.includes(entry.id)) {
@ -92,7 +89,7 @@ export const renderLayer = async (
}
}
for (const obj of layerState.objects) {
for (const obj of entity.objects) {
if (obj.type === 'brush_line') {
const entry = getBrushLine(mapping, obj, RASTER_LAYER_BRUSH_LINE_NAME);
// Only update the points if they have changed.
@ -108,13 +105,13 @@ export const renderLayer = async (
} else if (obj.type === 'rect_shape') {
getRectShape(mapping, obj, RASTER_LAYER_RECT_SHAPE_NAME);
} else if (obj.type === 'image') {
createImageObjectGroup(mapping, obj, RASTER_LAYER_IMAGE_NAME);
createImageObjectGroup({ mapping, obj, name: RASTER_LAYER_IMAGE_NAME });
}
}
// Only update layer visibility if it has changed.
if (mapping.konvaLayer.visible() !== layerState.isEnabled) {
mapping.konvaLayer.visible(layerState.isEnabled);
if (mapping.konvaLayer.visible() !== entity.isEnabled) {
mapping.konvaLayer.visible(entity.isEnabled);
}
// const bboxRect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(layerState, konvaLayer);
@ -135,23 +132,22 @@ export const renderLayer = async (
// bboxRect.visible(false);
// }
mapping.konvaObjectGroup.opacity(layerState.opacity);
mapping.konvaObjectGroup.opacity(entity.opacity);
};
export const renderLayers = (
stage: Konva.Stage,
layerMap: EntityToKonvaMap,
layers: LayerEntity[],
map: EntityToKonvaMap,
entities: LayerEntity[],
tool: Tool,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => {
// Destroy nonexistent layers
for (const mapping of layerMap.getMappings()) {
if (!layers.find((l) => l.id === mapping.id)) {
layerMap.destroyMapping(mapping.id);
for (const mapping of map.getMappings()) {
if (!entities.find((l) => l.id === mapping.id)) {
map.destroyMapping(mapping.id);
}
}
for (const layer of layers) {
renderLayer(stage, layerMap, layer, tool, onPosChanged);
for (const layer of entities) {
renderLayer(map, layer, tool, onPosChanged);
}
};

@ -9,14 +9,23 @@ import type {
import {
getLayerBboxId,
getObjectGroupId,
IMAGE_PLACEHOLDER_NAME,
LAYER_BBOX_NAME,
PREVIEW_GENERATION_BBOX_DUMMY_RECT,
} from 'features/controlLayers/konva/naming';
import type { BrushLine, CanvasEntity, EraserLine, ImageObject, RectShape } from 'features/controlLayers/store/types';
import type {
BrushLine,
CanvasEntity,
EraserLine,
ImageObject,
ImageWithDims,
RectShape,
} from 'features/controlLayers/store/types';
import { DEFAULT_RGBA_COLOR } from 'features/controlLayers/store/types';
import { t } from 'i18next';
import Konva from 'konva';
import { getImageDTO } from 'services/api/endpoints/images';
import { getImageDTO as defaultGetImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
/**
@ -120,16 +129,93 @@ export const getRectShape = (mapping: EntityToKonvaMapping, rectShape: RectShape
return entry;
};
export const updateImageSource = async (arg: {
entry: ImageEntry;
image: ImageWithDims;
getImageDTO?: (imageName: string) => Promise<ImageDTO | null>;
onLoading?: () => void;
onLoad?: (konvaImage: Konva.Image) => void;
onError?: () => void;
}) => {
const { entry, image, getImageDTO = defaultGetImageDTO, onLoading, onLoad, onError } = arg;
try {
entry.isLoading = true;
if (!entry.konvaImage) {
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image'));
}
onLoading?.();
const imageDTO = await getImageDTO(image.name);
if (!imageDTO) {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
return;
}
const imageEl = new Image();
imageEl.onload = () => {
if (entry.konvaImage) {
entry.konvaImage.setAttrs({
image: imageEl,
});
} else {
entry.konvaImage = new Konva.Image({
id: entry.id,
listening: false,
image: imageEl,
});
entry.konvaImageGroup.add(entry.konvaImage);
}
entry.isLoading = false;
entry.isError = false;
entry.konvaPlaceholderGroup.visible(false);
onLoad?.(entry.konvaImage);
};
imageEl.onerror = () => {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
};
imageEl.id = image.name;
imageEl.src = imageDTO.image_url;
} catch {
entry.isLoading = false;
entry.isError = true;
entry.konvaPlaceholderGroup.visible(true);
entry.konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
onError?.();
}
};
/**
* Creates an image placeholder group for an image object.
* @param imageObject The image object state
* @param image The image object state
* @returns The konva group for the image placeholder, and callbacks to handle loading and error states
*/
const createImagePlaceholderGroup = (
imageObject: ImageObject
): { konvaPlaceholderGroup: Konva.Group; onError: () => void; onLoading: () => void; onLoaded: () => void } => {
const { width, height } = imageObject.image;
const konvaPlaceholderGroup = new Konva.Group({ name: 'image-placeholder', listening: false });
export const createImageObjectGroup = (arg: {
mapping: EntityToKonvaMapping;
obj: ImageObject;
name: string;
getImageDTO?: (imageName: string) => Promise<ImageDTO | null>;
onLoad?: (konvaImage: Konva.Image) => void;
onLoading?: () => void;
onError?: () => void;
}): ImageEntry => {
const { mapping, obj, name, getImageDTO = defaultGetImageDTO, onLoad, onLoading, onError } = arg;
let entry = mapping.getEntry<ImageEntry>(obj.id);
if (entry) {
return entry;
}
const { id, image } = obj;
const { width, height } = obj;
const konvaImageGroup = new Konva.Group({ id, name, listening: false });
const konvaPlaceholderGroup = new Konva.Group({ name: IMAGE_PLACEHOLDER_NAME, listening: false });
const konvaPlaceholderRect = new Konva.Rect({
fill: 'hsl(220 12% 45% / 1)', // 'base.500'
width,
@ -137,7 +223,6 @@ const createImagePlaceholderGroup = (
listening: false,
});
const konvaPlaceholderText = new Konva.Text({
name: 'image-placeholder-text',
fill: 'hsl(220 12% 10% / 1)', // 'base.900'
width,
height,
@ -146,70 +231,25 @@ const createImagePlaceholderGroup = (
fontFamily: '"Inter Variable", sans-serif',
fontSize: width / 16,
fontStyle: '600',
text: 'Loading Image',
text: t('common.loadingImage', 'Loading Image'),
listening: false,
});
konvaPlaceholderGroup.add(konvaPlaceholderRect);
konvaPlaceholderGroup.add(konvaPlaceholderText);
const onError = () => {
konvaPlaceholderText.text(t('common.imageFailedToLoad', 'Image Failed to Load'));
};
const onLoading = () => {
konvaPlaceholderText.text(t('common.loadingImage', 'Loading Image'));
};
const onLoaded = () => {
konvaPlaceholderGroup.destroy();
};
return { konvaPlaceholderGroup, onError, onLoading, onLoaded };
};
/**
* Creates an image object group. Because images are loaded asynchronously, and we need to handle loading an error state,
* the image is rendered in a group, which includes a placeholder.
* @param imageObject The image object state
* @param layerObjectGroup The konva layer's object group to add the image to
* @param name The konva name for the image
* @returns A promise that resolves to the konva group for the image object
*/
export const createImageObjectGroup = async (
mapping: EntityToKonvaMapping,
imageObject: ImageObject,
name: string
): Promise<ImageEntry> => {
let entry = mapping.getEntry<ImageEntry>(imageObject.id);
if (entry) {
return entry;
}
const konvaImageGroup = new Konva.Group({ id: imageObject.id, name, listening: false });
const placeholder = createImagePlaceholderGroup(imageObject);
konvaImageGroup.add(placeholder.konvaPlaceholderGroup);
konvaImageGroup.add(konvaPlaceholderGroup);
mapping.konvaObjectGroup.add(konvaImageGroup);
entry = mapping.addEntry({ id: imageObject.id, type: 'image', konvaGroup: konvaImageGroup, konvaImage: null });
getImageDTO(imageObject.image.name).then((imageDTO) => {
if (!imageDTO) {
placeholder.onError();
return;
}
const imageEl = new Image();
imageEl.onload = () => {
const konvaImage = new Konva.Image({
id: imageObject.id,
name,
listening: false,
image: imageEl,
});
placeholder.onLoaded();
konvaImageGroup.add(konvaImage);
entry.konvaImage = konvaImage;
};
imageEl.onerror = () => {
placeholder.onError();
};
imageEl.id = imageObject.id;
imageEl.src = imageDTO.image_url;
entry = mapping.addEntry({
id,
type: 'image',
konvaImageGroup,
konvaPlaceholderGroup,
konvaPlaceholderText,
konvaImage: null,
isLoading: false,
isError: false,
});
updateImageSource({ entry, image, getImageDTO, onLoad, onLoading, onError });
return entry;
};

@ -45,22 +45,21 @@ const createCompositingRect = (konvaLayer: Konva.Layer): Konva.Rect => {
/**
* Creates a regional guidance layer.
* @param stage The konva stage
* @param region The regional guidance layer state
* @param entity The regional guidance layer state
* @param onLayerPosChanged Callback for when the layer's position changes
*/
const getRegion = (
stage: Konva.Stage,
regionMap: EntityToKonvaMap,
region: RegionEntity,
map: EntityToKonvaMap,
entity: RegionEntity,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): EntityToKonvaMapping => {
let mapping = regionMap.getMapping(region.id);
let mapping = map.getMapping(entity.id);
if (mapping) {
return mapping;
}
// This layer hasn't been added to the konva state yet
const konvaLayer = new Konva.Layer({
id: region.id,
id: entity.id,
name: RG_LAYER_NAME,
draggable: true,
dragDistance: 0,
@ -70,51 +69,48 @@ const getRegion = (
// the position - we do not need to call this on the `dragmove` event.
if (onPosChanged) {
konvaLayer.on('dragend', function (e) {
onPosChanged({ id: region.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'regional_guidance');
onPosChanged({ id: entity.id, x: Math.floor(e.target.x()), y: Math.floor(e.target.y()) }, 'regional_guidance');
});
}
const konvaObjectGroup = createObjectGroup(konvaLayer, RG_LAYER_OBJECT_GROUP_NAME);
konvaLayer.add(konvaObjectGroup);
stage.add(konvaLayer);
mapping = regionMap.addMapping(region.id, konvaLayer, konvaObjectGroup);
map.stage.add(konvaLayer);
mapping = map.addMapping(entity.id, konvaLayer, konvaObjectGroup);
return mapping;
};
/**
* Renders a raster layer.
* @param stage The konva stage
* @param region The regional guidance layer state
* @param entity The regional guidance layer state
* @param globalMaskLayerOpacity The global mask layer opacity
* @param tool The current tool
* @param onPosChanged Callback for when the layer's position changes
*/
export const renderRegion = (
stage: Konva.Stage,
regionMap: EntityToKonvaMap,
region: RegionEntity,
map: EntityToKonvaMap,
entity: RegionEntity,
globalMaskLayerOpacity: number,
tool: Tool,
selectedEntityIdentifier: CanvasEntityIdentifier | null,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => {
const mapping = getRegion(stage, regionMap, region, onPosChanged);
const mapping = getRegion(map, entity, onPosChanged);
// Update the layer's position and listening state
mapping.konvaLayer.setAttrs({
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
x: Math.floor(region.x),
y: Math.floor(region.y),
x: Math.floor(entity.x),
y: Math.floor(entity.y),
});
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
const rgbColor = rgbColorToString(region.fill);
const rgbColor = rgbColorToString(entity.fill);
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
let groupNeedsCache = false;
const objectIds = region.objects.map(mapId);
const objectIds = entity.objects.map(mapId);
// Destroy any objects that are no longer in state
for (const entry of mapping.getEntries()) {
if (!objectIds.includes(entry.id)) {
@ -123,7 +119,7 @@ export const renderRegion = (
}
}
for (const obj of region.objects) {
for (const obj of entity.objects) {
if (obj.type === 'brush_line') {
const entry = getBrushLine(mapping, obj, RG_LAYER_BRUSH_LINE_NAME);
@ -164,8 +160,8 @@ export const renderRegion = (
}
// Only update layer visibility if it has changed.
if (mapping.konvaLayer.visible() !== region.isEnabled) {
mapping.konvaLayer.visible(region.isEnabled);
if (mapping.konvaLayer.visible() !== entity.isEnabled) {
mapping.konvaLayer.visible(entity.isEnabled);
groupNeedsCache = true;
}
@ -177,7 +173,7 @@ export const renderRegion = (
const compositingRect =
mapping.konvaLayer.findOne<Konva.Rect>(`.${COMPOSITING_RECT_NAME}`) ?? createCompositingRect(mapping.konvaLayer);
const isSelected = selectedEntityIdentifier?.id === region.id;
const isSelected = selectedEntityIdentifier?.id === entity.id;
/**
* When the group is selected, we use a rect of the selected preview color, composited over the shapes. This allows
@ -200,7 +196,7 @@ export const renderRegion = (
compositingRect.setAttrs({
// The rect should be the size of the layer - use the fast method if we don't have a pixel-perfect bbox already
...(!region.bboxNeedsUpdate && region.bbox ? region.bbox : getLayerBboxFast(mapping.konvaLayer)),
...(!entity.bboxNeedsUpdate && entity.bbox ? entity.bbox : getLayerBboxFast(mapping.konvaLayer)),
fill: rgbColor,
opacity: globalMaskLayerOpacity,
// Draw this rect only where there are non-transparent pixels under it (e.g. the mask shapes)
@ -240,21 +236,20 @@ export const renderRegion = (
};
export const renderRegions = (
stage: Konva.Stage,
regionMap: EntityToKonvaMap,
regions: RegionEntity[],
map: EntityToKonvaMap,
entities: RegionEntity[],
maskOpacity: number,
tool: Tool,
selectedEntityIdentifier: CanvasEntityIdentifier | null,
onPosChanged?: (arg: PosChangedArg, entityType: CanvasEntity['type']) => void
): void => {
// Destroy nonexistent layers
for (const mapping of regionMap.getMappings()) {
if (!regions.find((rg) => rg.id === mapping.id)) {
regionMap.destroyMapping(mapping.id);
for (const mapping of map.getMappings()) {
if (!entities.find((rg) => rg.id === mapping.id)) {
map.destroyMapping(mapping.id);
}
}
for (const rg of regions) {
renderRegion(stage, regionMap, rg, maskOpacity, tool, selectedEntityIdentifier, onPosChanged);
for (const rg of entities) {
renderRegion(map, rg, maskOpacity, tool, selectedEntityIdentifier, onPosChanged);
}
};

@ -54,7 +54,6 @@ import type Konva from 'konva';
import type { IRect, Vector2d } from 'konva/lib/types';
import { debounce } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { getImageDTO } from 'services/api/endpoints/images';
/**
* Initializes the canvas renderer. It subscribes to the redux store and listens for changes directly, bypassing the
@ -283,9 +282,9 @@ export const initializeRenderer = (
// the entire state over when needed.
const debouncedUpdateBboxes = debounce(updateBboxes, 300);
const regionMap = new EntityToKonvaMap();
const layerMap = new EntityToKonvaMap();
const controlAdapterMap = new EntityToKonvaMap();
const regionMap = new EntityToKonvaMap(stage);
const layerMap = new EntityToKonvaMap(stage);
const controlAdapterMap = new EntityToKonvaMap(stage);
const renderCanvas = () => {
const { canvasV2 } = store.getState();
@ -304,7 +303,7 @@ export const initializeRenderer = (
canvasV2.tool.selected !== prevCanvasV2.tool.selected
) {
logIfDebugging('Rendering layers');
renderLayers(stage, layerMap, canvasV2.layers, canvasV2.tool.selected, onPosChanged);
renderLayers(layerMap, canvasV2.layers, canvasV2.tool.selected, onPosChanged);
}
if (
@ -315,7 +314,6 @@ export const initializeRenderer = (
) {
logIfDebugging('Rendering regions');
renderRegions(
stage,
regionMap,
canvasV2.regions,
canvasV2.settings.maskOpacity,
@ -327,7 +325,7 @@ export const initializeRenderer = (
if (isFirstRender || canvasV2.controlAdapters !== prevCanvasV2.controlAdapters) {
logIfDebugging('Rendering control adapters');
renderControlAdapters(stage, controlAdapterMap, canvasV2.controlAdapters, getImageDTO);
renderControlAdapters(controlAdapterMap, canvasV2.controlAdapters);
}
if (isFirstRender || canvasV2.document !== prevCanvasV2.document) {
@ -367,7 +365,15 @@ export const initializeRenderer = (
canvasV2.regions !== prevCanvasV2.regions
) {
logIfDebugging('Arranging entities');
arrangeEntities(stage, canvasV2.layers, canvasV2.controlAdapters, canvasV2.regions);
arrangeEntities(
stage,
layerMap,
canvasV2.layers,
controlAdapterMap,
canvasV2.controlAdapters,
regionMap,
canvasV2.regions
);
}
prevCanvasV2 = canvasV2;

@ -1,6 +1,5 @@
import {
CA_LAYER_NAME,
INITIAL_IMAGE_LAYER_NAME,
INPAINT_MASK_LAYER_NAME,
RASTER_LAYER_BRUSH_LINE_NAME,
RASTER_LAYER_ERASER_LINE_NAME,
@ -88,7 +87,6 @@ export const mapId = (object: { id: string }): string => object.id;
export const selectRenderableLayers = (node: Konva.Node): boolean =>
node.name() === RG_LAYER_NAME ||
node.name() === CA_LAYER_NAME ||
node.name() === INITIAL_IMAGE_LAYER_NAME ||
node.name() === RASTER_LAYER_NAME ||
node.name() === INPAINT_MASK_LAYER_NAME;

@ -28,6 +28,21 @@ const initialState: CanvasV2State = {
ipAdapters: [],
regions: [],
loras: [],
inpaintMask: {
bbox: null,
bboxNeedsUpdate: false,
fill: {
type: 'color_fill',
color: DEFAULT_RGBA_COLOR,
},
id: 'inpaint_mask',
imageCache: null,
isEnabled: false,
maskObjects: [],
type: 'inpaint_mask',
x: 0,
y: 0,
},
tool: {
selected: 'bbox',
selectedBuffer: null,

@ -18,7 +18,7 @@ import type {
T2IAdapterConfig,
T2IAdapterData,
} from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageWithDims } from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageObject } from './types';
export const selectCA = (state: CanvasV2State, id: string) => state.controlAdapters.find((ca) => ca.id === id);
export const selectCAOrThrow = (state: CanvasV2State, id: string) => {
@ -128,37 +128,43 @@ export const controlAdaptersReducers = {
}
moveToStart(state.controlAdapters, ca);
},
caImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
const ca = selectCA(state, id);
if (!ca) {
return;
}
ca.bbox = null;
ca.bboxNeedsUpdate = true;
ca.isEnabled = true;
if (imageDTO) {
const newImage = imageDTOToImageWithDims(imageDTO);
if (isEqual(newImage, ca.image)) {
caImageChanged: {
reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const { id, imageDTO, objectId } = action.payload;
const ca = selectCA(state, id);
if (!ca) {
return;
}
ca.image = newImage;
ca.processedImage = null;
} else {
ca.image = null;
ca.processedImage = null;
}
ca.bbox = null;
ca.bboxNeedsUpdate = true;
ca.isEnabled = true;
if (imageDTO) {
const newImageObject = imageDTOToImageObject(id, objectId, imageDTO);
if (isEqual(newImageObject, ca.imageObject)) {
return;
}
ca.imageObject = newImageObject;
ca.processedImageObject = null;
} else {
ca.imageObject = null;
ca.processedImageObject = null;
}
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
caProcessedImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
const ca = selectCA(state, id);
if (!ca) {
return;
}
ca.bbox = null;
ca.bboxNeedsUpdate = true;
ca.isEnabled = true;
ca.processedImage = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
caProcessedImageChanged: {
reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const { id, imageDTO, objectId } = action.payload;
const ca = selectCA(state, id);
if (!ca) {
return;
}
ca.bbox = null;
ca.bboxNeedsUpdate = true;
ca.isEnabled = true;
ca.processedImageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
caModelChanged: (
state,
@ -182,7 +188,7 @@ export const controlAdaptersReducers = {
if (candidateProcessorConfig?.type !== ca.processorConfig?.type) {
// The processor has changed. For example, the previous model was a Canny model and the new model is a Depth
// model. We need to use the new processor.
ca.processedImage = null;
ca.processedImageObject = null;
ca.processorConfig = candidateProcessorConfig;
}
@ -212,7 +218,7 @@ export const controlAdaptersReducers = {
}
ca.processorConfig = processorConfig;
if (!processorConfig) {
ca.processedImage = null;
ca.processedImageObject = null;
}
},
caFilterChanged: (state, action: PayloadAction<{ id: string; filter: Filter }>) => {

@ -4,8 +4,14 @@ import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
import type { CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPAdapterEntity, IPMethodV2 } from './types';
import { imageDTOToImageWithDims } from './types';
import type {
CanvasV2State,
CLIPVisionModelV2,
IPAdapterConfig,
IPAdapterEntity,
IPMethodV2,
} from './types';
import { imageDTOToImageObject } from './types';
export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.find((ipa) => ipa.id === id);
export const selectIPAOrThrow = (state: CanvasV2State, id: string) => {
@ -48,13 +54,16 @@ export const ipAdaptersReducers = {
ipaAllDeleted: (state) => {
state.ipAdapters = [];
},
ipaImageChanged: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
return;
}
ipa.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
ipaImageChanged: {
reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const { id, imageDTO, objectId } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
return;
}
ipa.imageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
ipaMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethodV2 }>) => {
const { id, method } = action.payload;

@ -1,6 +1,6 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getImageObjectId, getRectShapeId } from 'features/controlLayers/konva/naming';
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type { IRect } from 'konva/lib/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
@ -14,7 +14,7 @@ import type {
PointAddedToLineArg,
RectShapeAddedArg,
} from './types';
import { isLine } from './types';
import { imageDTOToImageObject, isLine } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.find((layer) => layer.id === id);
export const selectLayerOrThrow = (state: CanvasV2State, id: string) => {
@ -73,7 +73,9 @@ export const layersReducers = {
layer.bbox = bbox;
layer.bboxNeedsUpdate = false;
if (bbox === null) {
layer.objects = [];
// TODO(psyche): Clear objects when bbox is cleared - right now this doesn't work bc bbox calculation for layers
// doesn't work - always returns null
// layer.objects = [];
}
},
layerReset: (state, action: PayloadAction<{ id: string }>) => {
@ -212,24 +214,15 @@ export const layersReducers = {
prepare: (payload: RectShapeAddedArg) => ({ payload: { ...payload, rectId: uuidv4() } }),
},
layerImageAdded: {
reducer: (state, action: PayloadAction<ImageObjectAddedArg & { imageId: string }>) => {
const { id, imageId, imageDTO } = action.payload;
reducer: (state, action: PayloadAction<ImageObjectAddedArg & { objectId: string }>) => {
const { id, objectId, imageDTO } = action.payload;
const layer = selectLayer(state, id);
if (!layer) {
return;
}
const { width, height, image_name: name } = imageDTO;
layer.objects.push({
type: 'image',
id: getImageObjectId(id, imageId),
x: 0,
y: 0,
width,
height,
image: { width, height, name },
});
layer.objects.push(imageDTOToImageObject(id, objectId, imageDTO));
layer.bboxNeedsUpdate = true;
},
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, imageId: uuidv4() } }),
prepare: (payload: ImageObjectAddedArg) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
} satisfies SliceCaseReducers<CanvasV2State>;

@ -1,8 +1,12 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { getBrushLineId, getEraserLineId, getRectShapeId } from 'features/controlLayers/konva/naming';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import type {
CanvasV2State,
CLIPVisionModelV2,
IPMethodV2,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import type { IRect } from 'konva/lib/types';
@ -210,20 +214,25 @@ export const regionsReducers = {
}
rg.ipAdapters = rg.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
},
rgIPAdapterImageChanged: (
state,
action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null }>
) => {
const { id, ipAdapterId, imageDTO } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
rgIPAdapterImageChanged: {
reducer: (
state,
action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null; objectId: string }>
) => {
const { id, ipAdapterId, imageDTO, objectId } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.imageObject = imageDTO ? imageDTOToImageObject(id, objectId, imageDTO) : null;
},
prepare: (payload: { id: string; ipAdapterId: string; imageDTO: ImageDTO | null }) => ({
payload: { ...payload, objectId: uuidv4() },
}),
},
rgIPAdapterWeightChanged: (state, action: PayloadAction<{ id: string; ipAdapterId: string; weight: number }>) => {
const { id, ipAdapterId, weight } = action.payload;

@ -1,3 +1,4 @@
import { getImageObjectId } from 'features/controlLayers/konva/naming';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import type {
@ -536,6 +537,9 @@ const zRectShape = z.object({
});
export type RectShape = z.infer<typeof zRectShape>;
const zFilter = z.enum(['LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>;
const zImageObject = z.object({
id: zId,
type: z.literal('image'),
@ -544,6 +548,7 @@ const zImageObject = z.object({
y: z.number(),
width: z.number().min(1),
height: z.number().min(1),
filters: z.array(zFilter),
});
export type ImageObject = z.infer<typeof zImageObject>;
@ -569,7 +574,7 @@ export const zIPAdapterEntity = z.object({
isEnabled: z.boolean(),
weight: z.number().gte(-1).lte(2),
method: zIPMethodV2,
image: zImageWithDims.nullable(),
imageObject: zImageObject.nullable(),
model: zModelIdentifierField.nullable(),
clipVisionModel: zCLIPVisionModelV2,
beginEndStepPct: zBeginEndStepPct,
@ -577,7 +582,7 @@ export const zIPAdapterEntity = z.object({
export type IPAdapterEntity = z.infer<typeof zIPAdapterEntity>;
export type IPAdapterConfig = Pick<
IPAdapterEntity,
'weight' | 'image' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method'
'weight' | 'imageObject' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method'
>;
const zMaskObject = z
@ -642,7 +647,7 @@ const zImageFill = z.object({
src: z.string(),
});
const zFill = z.discriminatedUnion('type', [zColorFill, zImageFill]);
const zInpaintMaskData = z.object({
const zInpaintMaskEntity = z.object({
id: zId,
type: z.literal('inpaint_mask'),
isEnabled: z.boolean(),
@ -654,10 +659,7 @@ const zInpaintMaskData = z.object({
fill: zFill,
imageCache: zImageWithDims.nullable(),
});
export type InpaintMaskData = z.infer<typeof zInpaintMaskData>;
const zFilter = z.enum(['none', 'LightnessToAlphaFilter']);
export type Filter = z.infer<typeof zFilter>;
export type InpaintMaskEntity = z.infer<typeof zInpaintMaskEntity>;
const zControlAdapterEntityBase = z.object({
id: zId,
@ -670,8 +672,8 @@ const zControlAdapterEntityBase = z.object({
opacity: zOpacity,
filter: zFilter,
weight: z.number().gte(-1).lte(2),
image: zImageWithDims.nullable(),
processedImage: zImageWithDims.nullable(),
imageObject: zImageObject.nullable(),
processedImageObject: zImageObject.nullable(),
processorConfig: zProcessorConfig.nullable(),
processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct,
@ -693,8 +695,8 @@ export type ControlNetConfig = Pick<
ControlNetData,
| 'adapterType'
| 'weight'
| 'image'
| 'processedImage'
| 'imageObject'
| 'processedImageObject'
| 'processorConfig'
| 'beginEndStepPct'
| 'model'
@ -702,7 +704,7 @@ export type ControlNetConfig = Pick<
>;
export type T2IAdapterConfig = Pick<
T2IAdapterData,
'adapterType' | 'weight' | 'image' | 'processedImage' | 'processorConfig' | 'beginEndStepPct' | 'model'
'adapterType' | 'weight' | 'imageObject' | 'processedImageObject' | 'processorConfig' | 'beginEndStepPct' | 'model'
>;
export const initialControlNetV2: ControlNetConfig = {
@ -711,8 +713,8 @@ export const initialControlNetV2: ControlNetConfig = {
weight: 1,
beginEndStepPct: [0, 1],
controlMode: 'balanced',
image: null,
processedImage: null,
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
@ -721,13 +723,13 @@ export const initialT2IAdapterV2: T2IAdapterConfig = {
model: null,
weight: 1,
beginEndStepPct: [0, 1],
image: null,
processedImage: null,
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
export const initialIPAdapterV2: IPAdapterConfig = {
image: null,
imageObject: null,
model: null,
beginEndStepPct: [0, 1],
method: 'full',
@ -752,12 +754,30 @@ export const imageDTOToImageWithDims = ({ image_name, width, height }: ImageDTO)
height,
});
export const imageDTOToImageObject = (entityId: string, objectId: string, imageDTO: ImageDTO): ImageObject => {
const { width, height, image_name } = imageDTO;
return {
id: getImageObjectId(entityId, objectId),
type: 'image',
x: 0,
y: 0,
width,
height,
filters: [],
image: {
name: image_name,
width,
height,
},
};
};
const zBoundingBoxScaleMethod = z.enum(['none', 'auto', 'manual']);
export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
zBoundingBoxScaleMethod.safeParse(v).success;
export type CanvasEntity = LayerEntity | IPAdapterEntity | ControlAdapterEntity | RegionEntity | InpaintMaskData;
export type CanvasEntity = LayerEntity | ControlAdapterEntity | RegionEntity | InpaintMaskEntity | IPAdapterEntity;
export type CanvasEntityIdentifier = Pick<CanvasEntity, 'id' | 'type'>;
export type Dimensions = {
@ -775,6 +795,7 @@ export type LoRA = {
export type CanvasV2State = {
_version: 3;
selectedEntityIdentifier: CanvasEntityIdentifier | null;
inpaintMask: InpaintMaskEntity;
layers: LayerEntity[];
controlAdapters: ControlAdapterEntity[];
ipAdapters: IPAdapterEntity[];
@ -871,3 +892,14 @@ export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO };
export const isLine = (obj: RenderableObject): obj is BrushLine | EraserLine => {
return obj.type === 'brush_line' || obj.type === 'eraser_line';
};
/**
* A helper type to remove `[index: string]: any;` from a type.
* This is useful for some Konva types that include `[index: string]: any;` in addition to statically named
* properties, effectively widening the type signature to `Record<string, any>`. For example, `LineConfig`,
* `RectConfig`, `ImageConfig`, etc all include `[index: string]: any;` in their type signature.
* TODO(psyche): Fix this upstream.
*/
export type RemoveIndexString<T> = {
[K in keyof T as string extends K ? never : K]: T[K];
};

@ -25,7 +25,7 @@ export const getImageUsage = (nodes: NodesState, canvasV2: CanvasV2State, image_
(ca) => ca.image?.name === image_name || ca.processedImage?.name === image_name
);
const isIPAdapterImage = canvasV2.ipAdapters.some((ipa) => ipa.image?.name === image_name);
const isIPAdapterImage = canvasV2.ipAdapters.some((ipa) => ipa.imageObject?.name === image_name);
const imageUsage: ImageUsage = {
isLayerImage,

@ -692,7 +692,7 @@ const parseIPAdapterToIPAdapterLayer: MetadataParseFunc<IPAdapterEntity> = async
model: zModelIdentifierField.parse(ipAdapterModel),
weight: typeof weight === 'number' ? weight : initialIPAdapterV2.weight,
beginEndStepPct,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
imageObject: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
clipVisionModel: initialIPAdapterV2.clipVisionModel, // TODO: This needs to be added to the zIPAdapterField...
method: method ?? initialIPAdapterV2.method,
};

@ -278,10 +278,10 @@ const recallCA: MetadataRecallFunc<ControlAdapterEntity> = async (ca) => {
const recallIPA: MetadataRecallFunc<IPAdapterEntity> = async (ipa) => {
const { dispatch } = getStore();
const clone = deepClone(ipa);
if (clone.image) {
const imageDTO = await getImageDTO(clone.image.name);
if (clone.imageObject) {
const imageDTO = await getImageDTO(clone.imageObject.name);
if (!imageDTO) {
clone.image = null;
clone.imageObject = null;
}
}
if (clone.model) {
@ -305,10 +305,10 @@ const recallRG: MetadataRecallFunc<RegionEntity> = async (rg) => {
clone.imageCache = null;
for (const ipAdapter of clone.ipAdapters) {
if (ipAdapter.image) {
const imageDTO = await getImageDTO(ipAdapter.image.name);
if (ipAdapter.imageObject) {
const imageDTO = await getImageDTO(ipAdapter.imageObject.name);
if (!imageDTO) {
ipAdapter.image = null;
ipAdapter.imageObject = null;
}
}
if (ipAdapter.model) {

@ -34,7 +34,7 @@ export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise
};
const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa;
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -59,6 +59,6 @@ export const isValidIPAdapter = (ipa: IPAdapterEntity, base: BaseModelType): boo
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base;
const hasImage = Boolean(ipa.image);
const hasImage = Boolean(ipa.imageObject);
return hasModel && modelMatchesBase && hasImage;
};

@ -190,7 +190,7 @@ export const addRegions = async (
for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa;
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');