feat(ui): handle control adapter processed images

- Add helper functions to build metadata for control adapters, including the processed images
- Update parses to parse the new metadata
This commit is contained in:
psychedelicious 2024-03-14 16:46:18 +11:00 committed by Kent Keirsey
parent c24f2046e7
commit 21621eebf0
4 changed files with 187 additions and 71 deletions

View File

@ -225,7 +225,14 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
const image = zControlField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const image = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
@ -259,7 +266,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
@ -283,8 +290,18 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(await getProperty(metadataItem, 'weight'));
const image = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const weight = zT2IAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
@ -309,7 +326,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
@ -333,8 +350,14 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
const image = zIPAdapterField.shape.image.nullish().catch(null).parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(await getProperty(metadataItem, 'weight'));
const image = zIPAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)

View File

@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
@ -70,34 +74,12 @@ export const addControlNetToLinearGraph = async (
resize_mode: resizeMode,
control_model: model,
control_weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip CAs without an unprocessed image - should never happen, we already filtered the list of valid CAs
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
controlNetMetadata.push({
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
controlNetMetadata.push(buildControlNetMetadata(controlNet));
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
@ -110,3 +92,62 @@ export const addControlNetToLinearGraph = async (
upsertMetadata(graph, { controlnets: controlNetMetadata });
}
};
const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'ControlNet image is required');
return image;
};
const buildControlNetMetadata = (controlNet: ControlNetConfig): S['ControlNetMetadataField'] => {
const {
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
assert(model, 'ControlNet model is required');
const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;
assert(controlImage, 'ControlNet image is required');
return {
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};

View File

@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
@ -44,7 +48,10 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
@ -53,25 +60,14 @@ export const addIPAdapterToLinearGraph = async (
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
image_name: controlImage,
},
};
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapter.controlImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
@ -85,3 +81,27 @@ export const addIPAdapterToLinearGraph = async (
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
}
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
let image: ImageField | null = null;
if (controlImage) {
image = {
image_name: controlImage,
};
}
assert(image, 'IP Adapter image is required');
return {
ip_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,
};
};

View File

@ -1,11 +1,15 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';
import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
@ -68,33 +72,12 @@ export const addT2IAdaptersToLinearGraph = async (
resize_mode: resizeMode,
t2i_adapter_model: model,
weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
t2iAdapterNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
t2iAdapterNode.image = {
image_name: controlImage,
};
} else {
// Skip CAs without an unprocessed image - should never happen, we already filtered the list of valid CAs
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: t2iAdapter.model,
weight: weight,
image: t2iAdapterNode.image,
});
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
@ -108,3 +91,52 @@ export const addT2IAdaptersToLinearGraph = async (
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
}
};
const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'T2I Adapter image is required');
return image;
};
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfig): S['T2IAdapterMetadataField'] => {
const { controlImage, processedControlImage, beginStepPct, endStepPct, resizeMode, model, processorType, weight } =
t2iAdapter;
assert(model, 'T2I Adapter model is required');
const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;
assert(controlImage, 'T2I Adapter image is required');
return {
t2i_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};