feat(ui): use canvas as source for control images (wip)

This commit is contained in:
psychedelicious 2024-07-05 18:41:01 +10:00
parent 51008da2dd
commit d988e18731
5 changed files with 100 additions and 22 deletions

View File

@ -2,6 +2,7 @@ import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { import {
getControlAdapterImage,
getGenerationMode, getGenerationMode,
getImageSourceImage, getImageSourceImage,
getInpaintMaskImage, getInpaintMaskImage,
@ -369,6 +370,10 @@ export class CanvasManager {
return getGenerationMode({ manager: this }); return getGenerationMode({ manager: this });
} }
getControlAdapterImage(arg: Omit<Parameters<typeof getControlAdapterImage>[0], 'manager'>) {
return getControlAdapterImage({ ...arg, manager: this });
}
getRegionMaskImage(arg: Omit<Parameters<typeof getRegionMaskImage>[0], 'manager'>) { getRegionMaskImage(arg: Omit<Parameters<typeof getRegionMaskImage>[0], 'manager'>) {
return getRegionMaskImage({ ...arg, manager: this }); return getRegionMaskImage({ ...arg, manager: this });
} }

View File

@ -319,6 +319,24 @@ export function getRegionMaskLayerClone(arg: { manager: CanvasManager; id: strin
return layerClone; return layerClone;
} }
export function getControlAdapterLayerClone(arg: { manager: CanvasManager; id: string }): Konva.Layer {
const { id, manager } = arg;
const controlAdapter = manager.controlAdapters.get(id);
assert(controlAdapter, `Canvas region with id ${id} not found`);
const controlAdapterClone = controlAdapter.layer.clone();
const objectGroupClone = controlAdapter.group.clone();
controlAdapterClone.destroyChildren();
controlAdapterClone.add(objectGroupClone);
objectGroupClone.opacity(1);
objectGroupClone.cache();
return controlAdapterClone;
}
export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage { export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage {
const { manager } = arg; const { manager } = arg;
@ -406,6 +424,37 @@ export async function getRegionMaskImage(arg: {
return imageDTO; return imageDTO;
} }
export async function getControlAdapterImage(arg: {
manager: CanvasManager;
id: string;
bbox?: Rect;
preview?: boolean;
}): Promise<ImageDTO> {
const { manager, id, bbox, preview = false } = arg;
const ca = manager.stateApi.getControlAdaptersState().entities.find((entity) => entity.id === id);
assert(ca, `Control adapter entity state with id ${id} not found`);
// if (region.imageCache) {
// const imageDTO = await this.util.getImageDTO(region.imageCache.name);
// if (imageDTO) {
// return imageDTO;
// }
// }
const layerClone = getControlAdapterLayerClone({ id, manager });
const blob = await konvaNodeToBlob(layerClone, bbox);
if (preview) {
previewBlob(blob, `region ${ca.id} mask`);
}
layerClone.destroy();
const imageDTO = await manager.util.uploadImage(blob, `${ca.id}_control_image.png`, 'control', true);
// manager.stateApi.onRegionMaskImageCached(ca.id, imageDTO);
return imageDTO;
}
export async function getInpaintMaskImage(arg: { export async function getInpaintMaskImage(arg: {
manager: CanvasManager; manager: CanvasManager;
bbox?: Rect; bbox?: Rect;

View File

@ -1,8 +1,10 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { import type {
ControlAdapterEntity, ControlAdapterEntity,
ControlNetData, ControlNetData,
ImageWithDims, ImageWithDims,
ProcessorConfig, ProcessorConfig,
Rect,
T2IAdapterData, T2IAdapterData,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
@ -11,18 +13,20 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types'; import type { BaseModelType, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addControlAdapters = ( export const addControlAdapters = async (
manager: CanvasManager,
controlAdapters: ControlAdapterEntity[], controlAdapters: ControlAdapterEntity[],
g: Graph, g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
base: BaseModelType base: BaseModelType
): ControlAdapterEntity[] => { ): Promise<ControlAdapterEntity[]> => {
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base)); const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
for (const ca of validControlAdapters) { for (const ca of validControlAdapters) {
if (ca.adapterType === 'controlnet') { if (ca.adapterType === 'controlnet') {
addControlNetToGraph(ca, g, denoise); await addControlNetToGraph(manager, ca, g, bbox, denoise);
} else { } else {
addT2IAdapterToGraph(ca, g, denoise); await addT2IAdapterToGraph(manager, ca, g, bbox, denoise);
} }
} }
return validControlAdapters; return validControlAdapters;
@ -45,14 +49,17 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
} }
}; };
const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addControlNetToGraph = async (
const { id, beginEndStepPct, controlMode, imageObject, model, processedImageObject, processorConfig, weight } = ca; manager: CanvasManager,
ca: ControlNetData,
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, controlMode, model, weight } = ca;
assert(model, 'ControlNet model is required'); assert(model, 'ControlNet model is required');
const controlImage = buildControlImage( const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const controlNetCollect = addControlNetCollectorSafe(g, denoise); const controlNetCollect = addControlNetCollectorSafe(g, denoise);
const controlNet = g.addNode({ const controlNet = g.addNode({
@ -64,7 +71,7 @@ const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<
resize_mode: 'just_resize', resize_mode: 'just_resize',
control_model: model, control_model: model,
control_weight: weight, control_weight: weight,
image: controlImage, image: { image_name },
}); });
g.addEdge(controlNet, 'control', controlNetCollect, 'item'); g.addEdge(controlNet, 'control', controlNetCollect, 'item');
}; };
@ -87,14 +94,17 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
} }
}; };
const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addT2IAdapterToGraph = async (
const { id, beginEndStepPct, imageObject, model, processedImageObject, processorConfig, weight } = ca; manager: CanvasManager,
ca: T2IAdapterData,
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, model, weight } = ca;
assert(model, 'T2I Adapter model is required'); assert(model, 'T2I Adapter model is required');
const controlImage = buildControlImage( const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
const t2iAdapter = g.addNode({ const t2iAdapter = g.addNode({
@ -105,7 +115,7 @@ const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<
resize_mode: 'just_resize', resize_mode: 'just_resize',
t2i_adapter_model: model, t2i_adapter_model: model,
weight: weight, weight: weight,
image: controlImage, image: { image_name },
}); });
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');

View File

@ -210,7 +210,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
); );
} }
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
g,
state.canvasV2.bbox,
denoise,
modelConfig.base
);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions( const _addedRegions = await addRegions(
manager, manager,

View File

@ -214,7 +214,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
); );
} }
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base); const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
g,
state.canvasV2.bbox,
denoise,
modelConfig.base
);
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base); const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
const _addedRegions = await addRegions( const _addedRegions = await addRegions(
manager, manager,