tidy(ui): organise CL graph builder

This commit is contained in:
psychedelicious 2024-05-14 14:26:40 +10:00
parent b239891986
commit 4a1c3786a1

View File

@ -36,6 +36,20 @@ import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
/**
* Adds the control layers to the graph
* @param state The app root state
* @param g The graph to add the layers to
* @param base The base model type
* @param denoise The main denoise node
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
* @param negCondCollect The negative conditioning collector
* @param noise The noise node
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
* @returns A promise that resolves to the layers that were added to the graph
*/
export const addGenerationTabControlLayers = async (
state: RootState,
g: Graph,
@ -244,45 +258,7 @@ export const addGenerationTabControlLayers = async (
return validLayers;
};
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
//#region Control Adapters
const addGlobalControlAdapterToGraph = (
controlAdapterConfig: ControlNetConfigV2 | T2IAdapterConfigV2,
g: Graph,
@ -379,6 +355,7 @@ const addGlobalT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
};
//#region IP Adapter
const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
try {
// You see, we've already got one!
@ -420,7 +397,9 @@ const addGlobalIPAdapterToGraph = (
});
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
};
//#endregion
//#region Initial Image
const addInitialImageLayerToGraph = (
state: RootState,
g: Graph,
@ -488,7 +467,9 @@ const addInitialImageLayerToGraph = (
g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
};
//#endregion
//#region Layer validators
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model);
@ -534,3 +515,45 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
}
return false;
};
//#endregion
//#region Helpers
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
//#endregion