tidy(ui): organise graph builder files

This commit is contained in:
psychedelicious 2024-05-14 20:18:32 +10:00
parent cadea55521
commit 4c3c2297b9
10 changed files with 32 additions and 32 deletions

View File

@ -50,7 +50,7 @@ import { assert } from 'tsafe';
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader) * @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
* @returns A promise that resolves to the layers that were added to the graph * @returns A promise that resolves to the layers that were added to the graph
*/ */
export const addGenerationTabControlLayers = async ( export const addControlLayers = async (
state: RootState, state: RootState,
g: Graph, g: Graph,
base: BaseModelType, base: BaseModelType,

View File

@ -65,7 +65,7 @@ function calculateHrfRes(
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node) * @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
* @returns The HRF image output node. * @returns The HRF image output node.
*/ */
export const addGenerationTabHRF = ( export const addHRF = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,

View File

@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types'; import type { Invocation, S } from 'services/api/types';
export const addGenerationTabLoRAs = ( export const addLoRAs = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,

View File

@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
* @param imageOutput The current image output node * @param imageOutput The current image output node
* @returns The nsfw checker node * @returns The nsfw checker node
*/ */
export const addGenerationTabNSFWChecker = ( export const addNSFWChecker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_nsfw'> => { ): Invocation<'img_nsfw'> => {

View File

@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es'; import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types'; import type { Invocation, S } from 'services/api/types';
export const addGenerationTabSDXLLoRAs = ( export const addSDXLLoRas = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,

View File

@ -13,7 +13,7 @@ import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addGenerationTabSDXLRefiner = async ( export const addSDXLRefiner = async (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,

View File

@ -14,7 +14,7 @@ import type { Invocation } from 'services/api/types';
* @param vaeLoader The VAE loader node in the graph, if it exists * @param vaeLoader The VAE loader node in the graph, if it exists
* @returns The seamless node, if it was added to the graph * @returns The seamless node, if it was added to the graph
*/ */
export const addGenerationTabSeamless = ( export const addSeamless = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,

View File

@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
* @param imageOutput The image output node * @param imageOutput The image output node
* @returns The watermark node * @returns The watermark node
*/ */
export const addGenerationTabWatermarker = ( export const addWatermarker = (
g: Graph, g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_watermark'> => { ): Invocation<'img_watermark'> => {

View File

@ -14,12 +14,12 @@ import {
POSITIVE_CONDITIONING_COLLECT, POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER, VAE_LOADER,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers'; import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
import { addGenerationTabHRF } from 'features/nodes/util/graph/generation/addGenerationTabHRF'; import { addHRF } from 'features/nodes/util/graph/generation/addHRF';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabLoRAs'; import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph'; import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
@ -143,15 +143,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
vae: vae ?? undefined, vae: vae ?? undefined,
}); });
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader); const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond); addLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node. // We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader; const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae'); g.addEdge(vaeSource, 'vae', l2i, 'vae');
const addedLayers = await addGenerationTabControlLayers( const addedLayers = await addControlLayers(
state, state,
g, g,
modelConfig.base, modelConfig.base,
@ -166,15 +166,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l)); const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
if (isHRFAllowed && state.hrf.hrfEnabled) { if (isHRFAllowed && state.hrf.hrfEnabled) {
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource); imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource);
} }
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput); imageOutput = addNSFWChecker(g, imageOutput);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput); imageOutput = addWatermarker(g, imageOutput);
} }
g.setMetadataReceivingNode(imageOutput); g.setMetadataReceivingNode(imageOutput);

View File

@ -12,12 +12,12 @@ import {
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
VAE_LOADER, VAE_LOADER,
} from 'features/nodes/util/graph/constants'; } from 'features/nodes/util/graph/constants';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers'; import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabSDXLLoRAs'; import { addSDXLLoRas } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/generation/addGenerationTabSDXLRefiner'; import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils'; import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types'; import type { Invocation, NonNullableGraph } from 'services/api/types';
@ -135,9 +135,9 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
vae: vae ?? undefined, vae: vae ?? undefined,
}); });
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader); const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond); addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node. // We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader; const vaeSource = seamless ?? vaeLoader ?? modelLoader;
@ -145,10 +145,10 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
// Add Refiner if enabled // Add Refiner if enabled
if (refinerModel) { if (refinerModel) {
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i); await addSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
} }
await addGenerationTabControlLayers( await addControlLayers(
state, state,
g, g,
modelConfig.base, modelConfig.base,
@ -162,11 +162,11 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
); );
if (state.system.shouldUseNSFWChecker) { if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput); imageOutput = addNSFWChecker(g, imageOutput);
} }
if (state.system.shouldUseWatermarker) { if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput); imageOutput = addWatermarker(g, imageOutput);
} }
g.setMetadataReceivingNode(imageOutput); g.setMetadataReceivingNode(imageOutput);