feat(ui): use canvas as source for control images (wip)

This commit is contained in:
psychedelicious 2024-07-05 18:41:01 +10:00
parent 51008da2dd
commit d988e18731
5 changed files with 100 additions and 22 deletions

View File

@ -2,6 +2,7 @@ import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import {
getControlAdapterImage,
getGenerationMode,
getImageSourceImage,
getInpaintMaskImage,
@ -369,6 +370,10 @@ export class CanvasManager {
return getGenerationMode({ manager: this });
}
getControlAdapterImage(arg: Omit<Parameters<typeof getControlAdapterImage>[0], 'manager'>) {
return getControlAdapterImage({ ...arg, manager: this });
}
getRegionMaskImage(arg: Omit<Parameters<typeof getRegionMaskImage>[0], 'manager'>) {
return getRegionMaskImage({ ...arg, manager: this });
}

View File

@ -319,6 +319,24 @@ export function getRegionMaskLayerClone(arg: { manager: CanvasManager; id: strin
return layerClone;
}
export function getControlAdapterLayerClone(arg: { manager: CanvasManager; id: string }): Konva.Layer {
const { id, manager } = arg;
const controlAdapter = manager.controlAdapters.get(id);
assert(controlAdapter, `Canvas region with id ${id} not found`);
const controlAdapterClone = controlAdapter.layer.clone();
const objectGroupClone = controlAdapter.group.clone();
controlAdapterClone.destroyChildren();
controlAdapterClone.add(objectGroupClone);
objectGroupClone.opacity(1);
objectGroupClone.cache();
return controlAdapterClone;
}
export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage {
const { manager } = arg;
@ -406,6 +424,37 @@ export async function getRegionMaskImage(arg: {
return imageDTO;
}
export async function getControlAdapterImage(arg: {
manager: CanvasManager;
id: string;
bbox?: Rect;
preview?: boolean;
}): Promise<ImageDTO> {
const { manager, id, bbox, preview = false } = arg;
const ca = manager.stateApi.getControlAdaptersState().entities.find((entity) => entity.id === id);
assert(ca, `Control adapter entity state with id ${id} not found`);
// if (region.imageCache) {
// const imageDTO = await this.util.getImageDTO(region.imageCache.name);
// if (imageDTO) {
// return imageDTO;
// }
// }
const layerClone = getControlAdapterLayerClone({ id, manager });
const blob = await konvaNodeToBlob(layerClone, bbox);
if (preview) {
previewBlob(blob, `region ${ca.id} mask`);
}
layerClone.destroy();
const imageDTO = await manager.util.uploadImage(blob, `${ca.id}_control_image.png`, 'control', true);
// manager.stateApi.onRegionMaskImageCached(ca.id, imageDTO);
return imageDTO;
}
export async function getInpaintMaskImage(arg: {
manager: CanvasManager;
bbox?: Rect;

View File

@ -1,8 +1,10 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
ControlAdapterEntity,
ControlNetData,
ImageWithDims,
ProcessorConfig,
Rect,
T2IAdapterData,
} from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common';
@ -11,18 +13,20 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
export const addControlAdapters = (
export const addControlAdapters = async (
manager: CanvasManager,
controlAdapters: ControlAdapterEntity[],
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>,
base: BaseModelType
): ControlAdapterEntity[] => {
): Promise<ControlAdapterEntity[]> => {
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
for (const ca of validControlAdapters) {
if (ca.adapterType === 'controlnet') {
addControlNetToGraph(ca, g, denoise);
await addControlNetToGraph(manager, ca, g, bbox, denoise);
} else {
addT2IAdapterToGraph(ca, g, denoise);
await addT2IAdapterToGraph(manager, ca, g, bbox, denoise);
}
}
return validControlAdapters;
@ -45,14 +49,17 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}
};
const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, beginEndStepPct, controlMode, imageObject, model, processedImageObject, processorConfig, weight } = ca;
const addControlNetToGraph = async (
manager: CanvasManager,
ca: ControlNetData,
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, controlMode, model, weight } = ca;
assert(model, 'ControlNet model is required');
const controlImage = buildControlImage(
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
const controlNet = g.addNode({
@ -64,7 +71,7 @@ const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
image: controlImage,
image: { image_name },
});
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
};
@ -87,14 +94,17 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}
};
const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, beginEndStepPct, imageObject, model, processedImageObject, processorConfig, weight } = ca;
const addT2IAdapterToGraph = async (
manager: CanvasManager,
ca: T2IAdapterData,
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, model, weight } = ca;
assert(model, 'T2I Adapter model is required');
const controlImage = buildControlImage(
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
const t2iAdapter = g.addNode({
@ -105,7 +115,7 @@ const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<
resize_mode: 'just_resize',
t2i_adapter_model: model,
weight: weight,
image: controlImage,
image: { image_name },
});
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');

View File

@ -210,7 +210,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
);
}
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
g,
state.canvasV2.bbox,
denoise,
modelConfig.base
);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,

View File

@ -214,7 +214,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
);
}
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
g,
state.canvasV2.bbox,
denoise,
modelConfig.base
);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions(
manager,