tidy(ui): organise CL graph builder

This commit is contained in:
psychedelicious 2024-05-14 14:26:40 +10:00
parent b239891986
commit 4a1c3786a1

View File

@ -36,6 +36,20 @@ import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
/**
* Adds the control layers to the graph
* @param state The app root state
* @param g The graph to add the layers to
* @param base The base model type
* @param denoise The main denoise node
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
* @param negCondCollect The negative conditioning collector
* @param noise The noise node
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
* @returns A promise that resolves to the layers that were added to the graph
*/
export const addGenerationTabControlLayers = async ( export const addGenerationTabControlLayers = async (
state: RootState, state: RootState,
g: Graph, g: Graph,
@ -244,45 +258,7 @@ export const addGenerationTabControlLayers = async (
return validLayers; return validLayers;
}; };
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => { //#region Control Adapters
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
const addGlobalControlAdapterToGraph = ( const addGlobalControlAdapterToGraph = (
controlAdapterConfig: ControlNetConfigV2 | T2IAdapterConfigV2, controlAdapterConfig: ControlNetConfigV2 | T2IAdapterConfigV2,
g: Graph, g: Graph,
@ -379,6 +355,7 @@ const addGlobalT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item'); g.addEdge(t2iAdapter, 't2i_adapter', t2iAdapterCollect, 'item');
}; };
//#region IP Adapter
const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
try { try {
// You see, we've already got one! // You see, we've already got one!
@ -420,7 +397,9 @@ const addGlobalIPAdapterToGraph = (
}); });
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item'); g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
}; };
//#endregion
//#region Initial Image
const addInitialImageLayerToGraph = ( const addInitialImageLayerToGraph = (
state: RootState, state: RootState,
g: Graph, g: Graph,
@ -488,7 +467,9 @@ const addInitialImageLayerToGraph = (
g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' }); g.upsertMetadata({ generation_mode: isSDXL ? 'sdxl_img2img' : 'img2img' });
}; };
//#endregion
//#region Layer validators
const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => { const isValidControlAdapter = (ca: ControlNetConfigV2 | T2IAdapterConfigV2, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image // Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model); const hasModel = Boolean(ca.model);
@ -534,3 +515,45 @@ const isValidLayer = (layer: Layer, base: BaseModelType) => {
} }
return false; return false;
}; };
//#endregion
//#region Helpers
const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.name);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
return {
image_name: processedImage.name,
};
} else if (image) {
// No processor selected, and we have an image - the user provided a processed image, use it for the control image.
return {
image_name: image.name,
};
}
assert(false, 'Attempted to add unprocessed control image');
};
//#endregion