mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): use canvas as source for control images (wip)
This commit is contained in:
parent
51008da2dd
commit
d988e18731
@ -2,6 +2,7 @@ import type { Store } from '@reduxjs/toolkit';
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
getControlAdapterImage,
|
||||||
getGenerationMode,
|
getGenerationMode,
|
||||||
getImageSourceImage,
|
getImageSourceImage,
|
||||||
getInpaintMaskImage,
|
getInpaintMaskImage,
|
||||||
@ -369,6 +370,10 @@ export class CanvasManager {
|
|||||||
return getGenerationMode({ manager: this });
|
return getGenerationMode({ manager: this });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getControlAdapterImage(arg: Omit<Parameters<typeof getControlAdapterImage>[0], 'manager'>) {
|
||||||
|
return getControlAdapterImage({ ...arg, manager: this });
|
||||||
|
}
|
||||||
|
|
||||||
getRegionMaskImage(arg: Omit<Parameters<typeof getRegionMaskImage>[0], 'manager'>) {
|
getRegionMaskImage(arg: Omit<Parameters<typeof getRegionMaskImage>[0], 'manager'>) {
|
||||||
return getRegionMaskImage({ ...arg, manager: this });
|
return getRegionMaskImage({ ...arg, manager: this });
|
||||||
}
|
}
|
||||||
|
@ -319,6 +319,24 @@ export function getRegionMaskLayerClone(arg: { manager: CanvasManager; id: strin
|
|||||||
return layerClone;
|
return layerClone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getControlAdapterLayerClone(arg: { manager: CanvasManager; id: string }): Konva.Layer {
|
||||||
|
const { id, manager } = arg;
|
||||||
|
|
||||||
|
const controlAdapter = manager.controlAdapters.get(id);
|
||||||
|
assert(controlAdapter, `Canvas region with id ${id} not found`);
|
||||||
|
|
||||||
|
const controlAdapterClone = controlAdapter.layer.clone();
|
||||||
|
const objectGroupClone = controlAdapter.group.clone();
|
||||||
|
|
||||||
|
controlAdapterClone.destroyChildren();
|
||||||
|
controlAdapterClone.add(objectGroupClone);
|
||||||
|
|
||||||
|
objectGroupClone.opacity(1);
|
||||||
|
objectGroupClone.cache();
|
||||||
|
|
||||||
|
return controlAdapterClone;
|
||||||
|
}
|
||||||
|
|
||||||
export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage {
|
export function getCompositeLayerStageClone(arg: { manager: CanvasManager }): Konva.Stage {
|
||||||
const { manager } = arg;
|
const { manager } = arg;
|
||||||
|
|
||||||
@ -406,6 +424,37 @@ export async function getRegionMaskImage(arg: {
|
|||||||
return imageDTO;
|
return imageDTO;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getControlAdapterImage(arg: {
|
||||||
|
manager: CanvasManager;
|
||||||
|
id: string;
|
||||||
|
bbox?: Rect;
|
||||||
|
preview?: boolean;
|
||||||
|
}): Promise<ImageDTO> {
|
||||||
|
const { manager, id, bbox, preview = false } = arg;
|
||||||
|
const ca = manager.stateApi.getControlAdaptersState().entities.find((entity) => entity.id === id);
|
||||||
|
assert(ca, `Control adapter entity state with id ${id} not found`);
|
||||||
|
|
||||||
|
// if (region.imageCache) {
|
||||||
|
// const imageDTO = await this.util.getImageDTO(region.imageCache.name);
|
||||||
|
// if (imageDTO) {
|
||||||
|
// return imageDTO;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
const layerClone = getControlAdapterLayerClone({ id, manager });
|
||||||
|
const blob = await konvaNodeToBlob(layerClone, bbox);
|
||||||
|
|
||||||
|
if (preview) {
|
||||||
|
previewBlob(blob, `region ${ca.id} mask`);
|
||||||
|
}
|
||||||
|
|
||||||
|
layerClone.destroy();
|
||||||
|
|
||||||
|
const imageDTO = await manager.util.uploadImage(blob, `${ca.id}_control_image.png`, 'control', true);
|
||||||
|
// manager.stateApi.onRegionMaskImageCached(ca.id, imageDTO);
|
||||||
|
return imageDTO;
|
||||||
|
}
|
||||||
|
|
||||||
export async function getInpaintMaskImage(arg: {
|
export async function getInpaintMaskImage(arg: {
|
||||||
manager: CanvasManager;
|
manager: CanvasManager;
|
||||||
bbox?: Rect;
|
bbox?: Rect;
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||||
import type {
|
import type {
|
||||||
ControlAdapterEntity,
|
ControlAdapterEntity,
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
ImageWithDims,
|
ImageWithDims,
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
|
Rect,
|
||||||
T2IAdapterData,
|
T2IAdapterData,
|
||||||
} from 'features/controlLayers/store/types';
|
} from 'features/controlLayers/store/types';
|
||||||
import type { ImageField } from 'features/nodes/types/common';
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
@ -11,18 +13,20 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
|||||||
import type { BaseModelType, Invocation } from 'services/api/types';
|
import type { BaseModelType, Invocation } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
export const addControlAdapters = (
|
export const addControlAdapters = async (
|
||||||
|
manager: CanvasManager,
|
||||||
controlAdapters: ControlAdapterEntity[],
|
controlAdapters: ControlAdapterEntity[],
|
||||||
g: Graph,
|
g: Graph,
|
||||||
|
bbox: Rect,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
||||||
base: BaseModelType
|
base: BaseModelType
|
||||||
): ControlAdapterEntity[] => {
|
): Promise<ControlAdapterEntity[]> => {
|
||||||
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
|
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
|
||||||
for (const ca of validControlAdapters) {
|
for (const ca of validControlAdapters) {
|
||||||
if (ca.adapterType === 'controlnet') {
|
if (ca.adapterType === 'controlnet') {
|
||||||
addControlNetToGraph(ca, g, denoise);
|
await addControlNetToGraph(manager, ca, g, bbox, denoise);
|
||||||
} else {
|
} else {
|
||||||
addT2IAdapterToGraph(ca, g, denoise);
|
await addT2IAdapterToGraph(manager, ca, g, bbox, denoise);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return validControlAdapters;
|
return validControlAdapters;
|
||||||
@ -45,14 +49,17 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
|
const addControlNetToGraph = async (
|
||||||
const { id, beginEndStepPct, controlMode, imageObject, model, processedImageObject, processorConfig, weight } = ca;
|
manager: CanvasManager,
|
||||||
|
ca: ControlNetData,
|
||||||
|
g: Graph,
|
||||||
|
bbox: Rect,
|
||||||
|
denoise: Invocation<'denoise_latents'>
|
||||||
|
) => {
|
||||||
|
const { id, beginEndStepPct, controlMode, model, weight } = ca;
|
||||||
assert(model, 'ControlNet model is required');
|
assert(model, 'ControlNet model is required');
|
||||||
const controlImage = buildControlImage(
|
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
|
||||||
imageObject?.image ?? null,
|
|
||||||
processedImageObject?.image ?? null,
|
|
||||||
processorConfig
|
|
||||||
);
|
|
||||||
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
|
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
|
||||||
|
|
||||||
const controlNet = g.addNode({
|
const controlNet = g.addNode({
|
||||||
@ -64,7 +71,7 @@ const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<
|
|||||||
resize_mode: 'just_resize',
|
resize_mode: 'just_resize',
|
||||||
control_model: model,
|
control_model: model,
|
||||||
control_weight: weight,
|
control_weight: weight,
|
||||||
image: controlImage,
|
image: { image_name },
|
||||||
});
|
});
|
||||||
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
|
g.addEdge(controlNet, 'control', controlNetCollect, 'item');
|
||||||
};
|
};
|
||||||
@ -87,14 +94,17 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
|
const addT2IAdapterToGraph = async (
|
||||||
const { id, beginEndStepPct, imageObject, model, processedImageObject, processorConfig, weight } = ca;
|
manager: CanvasManager,
|
||||||
|
ca: T2IAdapterData,
|
||||||
|
g: Graph,
|
||||||
|
bbox: Rect,
|
||||||
|
denoise: Invocation<'denoise_latents'>
|
||||||
|
) => {
|
||||||
|
const { id, beginEndStepPct, model, weight } = ca;
|
||||||
assert(model, 'T2I Adapter model is required');
|
assert(model, 'T2I Adapter model is required');
|
||||||
const controlImage = buildControlImage(
|
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
|
||||||
imageObject?.image ?? null,
|
|
||||||
processedImageObject?.image ?? null,
|
|
||||||
processorConfig
|
|
||||||
);
|
|
||||||
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
|
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
|
||||||
|
|
||||||
const t2iAdapter = g.addNode({
|
const t2iAdapter = g.addNode({
|
||||||
@ -105,7 +115,7 @@ const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<
|
|||||||
resize_mode: 'just_resize',
|
resize_mode: 'just_resize',
|
||||||
t2i_adapter_model: model,
|
t2i_adapter_model: model,
|
||||||
weight: weight,
|
weight: weight,
|
||||||
image: controlImage,
|
image: { image_name },
|
||||||
});
|
});
|
||||||
|
|
||||||
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
|
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
|
||||||
|
@ -210,7 +210,14 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
const _addedCAs = await addControlAdapters(
|
||||||
|
manager,
|
||||||
|
state.canvasV2.controlAdapters.entities,
|
||||||
|
g,
|
||||||
|
state.canvasV2.bbox,
|
||||||
|
denoise,
|
||||||
|
modelConfig.base
|
||||||
|
);
|
||||||
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedRegions = await addRegions(
|
const _addedRegions = await addRegions(
|
||||||
manager,
|
manager,
|
||||||
|
@ -214,7 +214,14 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const _addedCAs = addControlAdapters(state.canvasV2.controlAdapters.entities, g, denoise, modelConfig.base);
|
const _addedCAs = await addControlAdapters(
|
||||||
|
manager,
|
||||||
|
state.canvasV2.controlAdapters.entities,
|
||||||
|
g,
|
||||||
|
state.canvasV2.bbox,
|
||||||
|
denoise,
|
||||||
|
modelConfig.base
|
||||||
|
);
|
||||||
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
const _addedIPAs = addIPAdapters(state.canvasV2.ipAdapters.entities, g, denoise, modelConfig.base);
|
||||||
const _addedRegions = await addRegions(
|
const _addedRegions = await addRegions(
|
||||||
manager,
|
manager,
|
||||||
|
Loading…
Reference in New Issue
Block a user