fix(ui): fix generation graphs

This commit is contained in:
psychedelicious 2024-06-20 11:12:28 +10:00
parent ce8a7bc178
commit c47e02c309
5 changed files with 24 additions and 15 deletions

View File

@ -141,9 +141,9 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
} }
// Must have a control image OR, if it has a processor, it must have a processed image // Must have a control image OR, if it has a processor, it must have a processed image
if (!ca.image) { if (!ca.imageObject) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected')); problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (ca.processorConfig && !ca.processedImage) { } else if (ca.processorConfig && !ca.processedImageObject) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed')); problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
} }
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL) // T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)

View File

@ -46,9 +46,13 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}; };
const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addControlNetToGraph = (ca: ControlNetData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, beginEndStepPct, controlMode, image, model, processedImage, processorConfig, weight } = ca; const { id, beginEndStepPct, controlMode, imageObject, model, processedImageObject, processorConfig, weight } = ca;
assert(model, 'ControlNet model is required'); assert(model, 'ControlNet model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig); const controlImage = buildControlImage(
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const controlNetCollect = addControlNetCollectorSafe(g, denoise); const controlNetCollect = addControlNetCollectorSafe(g, denoise);
const controlNet = g.addNode({ const controlNet = g.addNode({
@ -84,9 +88,13 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}; };
const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addT2IAdapterToGraph = (ca: T2IAdapterData, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, beginEndStepPct, image, model, processedImage, processorConfig, weight } = ca; const { id, beginEndStepPct, imageObject, model, processedImageObject, processorConfig, weight } = ca;
assert(model, 'T2I Adapter model is required'); assert(model, 'T2I Adapter model is required');
const controlImage = buildControlImage(image, processedImage, processorConfig); const controlImage = buildControlImage(
imageObject?.image ?? null,
processedImageObject?.image ?? null,
processorConfig
);
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
const t2iAdapter = g.addNode({ const t2iAdapter = g.addNode({
@ -126,6 +134,6 @@ const isValidControlAdapter = (ca: ControlAdapterEntity, base: BaseModelType): b
// Must be have a model that matches the current base and must have a control image // Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model); const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === base; const modelMatchesBase = ca.model?.base === base;
const hasControlImage = Boolean(ca.image || (ca.processedImage && ca.processorConfig)); const hasControlImage = Boolean(ca.imageObject || (ca.processedImageObject && ca.processorConfig));
return hasModel && modelMatchesBase && hasControlImage; return hasModel && modelMatchesBase && hasControlImage;
}; };

View File

@ -34,8 +34,8 @@ export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise
}; };
const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => { const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject } = ipa;
assert(image, 'IP Adapter image is required'); assert(imageObject, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -49,7 +49,7 @@ const addIPAdapter = (ipa: IPAdapterEntity, g: Graph, denoise: Invocation<'denoi
begin_step_percent: beginEndStepPct[0], begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1], end_step_percent: beginEndStepPct[1],
image: { image: {
image_name: image.name, image_name: imageObject.image.name,
}, },
}); });
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');

View File

@ -2,6 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming'; import { RG_LAYER_NAME } from 'features/controlLayers/konva/naming';
import { KonvaNodeManager } from 'features/controlLayers/konva/nodeManager';
import { renderRegions } from 'features/controlLayers/konva/renderers/regions'; import { renderRegions } from 'features/controlLayers/konva/renderers/regions';
import { blobToDataURL } from 'features/controlLayers/konva/util'; import { blobToDataURL } from 'features/controlLayers/konva/util';
import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice'; import { rgMaskImageUploaded } from 'features/controlLayers/store/canvasV2Slice';
@ -190,9 +191,9 @@ export const addRegions = async (
for (const ipa of validRGIPAdapters) { for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject: image } = ipa; const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject } = ipa;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required'); assert(imageObject, 'IP Adapter image is required');
const ipAdapter = g.addNode({ const ipAdapter = g.addNode({
id: `ip_adapter_${id}`, id: `ip_adapter_${id}`,
@ -204,7 +205,7 @@ export const addRegions = async (
begin_step_percent: beginEndStepPct[0], begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1], end_step_percent: beginEndStepPct[1],
image: { image: {
image_name: image.name, image_name: imageObject.image.name,
}, },
}); });
@ -260,7 +261,7 @@ export const getRGMaskBlobs = async (
): Promise<Record<string, Blob>> => { ): Promise<Record<string, Blob>> => {
const container = document.createElement('div'); const container = document.createElement('div');
const stage = new Konva.Stage({ container, ...documentSize }); const stage = new Konva.Stage({ container, ...documentSize });
renderRegions(stage, regions, 1, 'brush', null); renderRegions(new KonvaNodeManager(stage), regions, 1, 'brush', null);
const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`); const konvaLayers = stage.find<Konva.Layer>(`.${RG_LAYER_NAME}`);
const blobs: Record<string, Blob> = {}; const blobs: Record<string, Blob> = {};

View File

@ -43,7 +43,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
refinerModel, refinerModel,
refinerStart, refinerStart,
} = state.canvasV2.params; } = state.canvasV2.params;
const { width, height } = state.canvasV2.document; const { width, height } = state.canvasV2.bbox;
assert(model, 'No model found in state'); assert(model, 'No model found in state');