tidy(ui): organise graph builder files

This commit is contained in:
psychedelicious 2024-05-14 20:18:32 +10:00
parent cadea55521
commit 4c3c2297b9
10 changed files with 32 additions and 32 deletions

View File

@ -50,7 +50,7 @@ import { assert } from 'tsafe';
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
* @returns A promise that resolves to the layers that were added to the graph
*/
export const addGenerationTabControlLayers = async (
export const addControlLayers = async (
state: RootState,
g: Graph,
base: BaseModelType,

View File

@ -65,7 +65,7 @@ function calculateHrfRes(
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
* @returns The HRF image output node.
*/
export const addGenerationTabHRF = (
export const addHRF = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,

View File

@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
export const addGenerationTabLoRAs = (
export const addLoRAs = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,

View File

@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
* @param imageOutput The current image output node
* @returns The nsfw checker node
*/
export const addGenerationTabNSFWChecker = (
export const addNSFWChecker = (
g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_nsfw'> => {

View File

@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { filter, size } from 'lodash-es';
import type { Invocation, S } from 'services/api/types';
export const addGenerationTabSDXLLoRAs = (
export const addSDXLLoRas = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,

View File

@ -13,7 +13,7 @@ import type { Invocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
export const addGenerationTabSDXLRefiner = async (
export const addSDXLRefiner = async (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,

View File

@ -14,7 +14,7 @@ import type { Invocation } from 'services/api/types';
* @param vaeLoader The VAE loader node in the graph, if it exists
* @returns The seamless node, if it was added to the graph
*/
export const addGenerationTabSeamless = (
export const addSeamless = (
state: RootState,
g: Graph,
denoise: Invocation<'denoise_latents'>,

View File

@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
* @param imageOutput The image output node
* @returns The watermark node
*/
export const addGenerationTabWatermarker = (
export const addWatermarker = (
g: Graph,
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
): Invocation<'img_watermark'> => {

View File

@ -14,12 +14,12 @@ import {
POSITIVE_CONDITIONING_COLLECT,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers';
import { addGenerationTabHRF } from 'features/nodes/util/graph/generation/addGenerationTabHRF';
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabLoRAs';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker';
import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
import { addHRF } from 'features/nodes/util/graph/generation/addHRF';
import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
@ -143,15 +143,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
vae: vae ?? undefined,
});
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
addLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
const addedLayers = await addGenerationTabControlLayers(
const addedLayers = await addControlLayers(
state,
g,
modelConfig.base,
@ -166,15 +166,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
if (isHRFAllowed && state.hrf.hrfEnabled) {
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource);
}
if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
imageOutput = addNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput);
imageOutput = addWatermarker(g, imageOutput);
}
g.setMetadataReceivingNode(imageOutput);

View File

@ -12,12 +12,12 @@ import {
SDXL_MODEL_LOADER,
VAE_LOADER,
} from 'features/nodes/util/graph/constants';
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers';
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker';
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabSDXLLoRAs';
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/generation/addGenerationTabSDXLRefiner';
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless';
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker';
import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addSDXLLoRas } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types';
@ -135,9 +135,9 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
vae: vae ?? undefined,
});
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
@ -145,10 +145,10 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
// Add Refiner if enabled
if (refinerModel) {
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
await addSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
}
await addGenerationTabControlLayers(
await addControlLayers(
state,
g,
modelConfig.base,
@ -162,11 +162,11 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
);
if (state.system.shouldUseNSFWChecker) {
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
imageOutput = addNSFWChecker(g, imageOutput);
}
if (state.system.shouldUseWatermarker) {
imageOutput = addGenerationTabWatermarker(g, imageOutput);
imageOutput = addWatermarker(g, imageOutput);
}
g.setMetadataReceivingNode(imageOutput);