mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): organise graph builder files
This commit is contained in:
parent
cadea55521
commit
4c3c2297b9
@ -50,7 +50,7 @@ import { assert } from 'tsafe';
|
|||||||
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
|
* @param vaeSource The VAE source (either seamless, vae_loader, main_model_loader, or sdxl_model_loader)
|
||||||
* @returns A promise that resolves to the layers that were added to the graph
|
* @returns A promise that resolves to the layers that were added to the graph
|
||||||
*/
|
*/
|
||||||
export const addGenerationTabControlLayers = async (
|
export const addControlLayers = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
base: BaseModelType,
|
base: BaseModelType,
|
@ -65,7 +65,7 @@ function calculateHrfRes(
|
|||||||
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
|
* @param vaeSource The VAE source node (may be a model loader, VAE loader, or seamless node)
|
||||||
* @returns The HRF image output node.
|
* @returns The HRF image output node.
|
||||||
*/
|
*/
|
||||||
export const addGenerationTabHRF = (
|
export const addHRF = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
|||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { Invocation, S } from 'services/api/types';
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
|
||||||
export const addGenerationTabLoRAs = (
|
export const addLoRAs = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
|
|||||||
* @param imageOutput The current image output node
|
* @param imageOutput The current image output node
|
||||||
* @returns The nsfw checker node
|
* @returns The nsfw checker node
|
||||||
*/
|
*/
|
||||||
export const addGenerationTabNSFWChecker = (
|
export const addNSFWChecker = (
|
||||||
g: Graph,
|
g: Graph,
|
||||||
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
||||||
): Invocation<'img_nsfw'> => {
|
): Invocation<'img_nsfw'> => {
|
@ -5,7 +5,7 @@ import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
|||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { Invocation, S } from 'services/api/types';
|
import type { Invocation, S } from 'services/api/types';
|
||||||
|
|
||||||
export const addGenerationTabSDXLLoRAs = (
|
export const addSDXLLoRas = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
@ -13,7 +13,7 @@ import type { Invocation } from 'services/api/types';
|
|||||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
export const addGenerationTabSDXLRefiner = async (
|
export const addSDXLRefiner = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
@ -14,7 +14,7 @@ import type { Invocation } from 'services/api/types';
|
|||||||
* @param vaeLoader The VAE loader node in the graph, if it exists
|
* @param vaeLoader The VAE loader node in the graph, if it exists
|
||||||
* @returns The seamless node, if it was added to the graph
|
* @returns The seamless node, if it was added to the graph
|
||||||
*/
|
*/
|
||||||
export const addGenerationTabSeamless = (
|
export const addSeamless = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
@ -8,7 +8,7 @@ import type { Invocation } from 'services/api/types';
|
|||||||
* @param imageOutput The image output node
|
* @param imageOutput The image output node
|
||||||
* @returns The watermark node
|
* @returns The watermark node
|
||||||
*/
|
*/
|
||||||
export const addGenerationTabWatermarker = (
|
export const addWatermarker = (
|
||||||
g: Graph,
|
g: Graph,
|
||||||
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'>
|
||||||
): Invocation<'img_watermark'> => {
|
): Invocation<'img_watermark'> => {
|
@ -14,12 +14,12 @@ import {
|
|||||||
POSITIVE_CONDITIONING_COLLECT,
|
POSITIVE_CONDITIONING_COLLECT,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers';
|
import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
|
||||||
import { addGenerationTabHRF } from 'features/nodes/util/graph/generation/addGenerationTabHRF';
|
import { addHRF } from 'features/nodes/util/graph/generation/addHRF';
|
||||||
import { addGenerationTabLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabLoRAs';
|
import { addLoRAs } from 'features/nodes/util/graph/generation/addLoRAs';
|
||||||
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker';
|
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless';
|
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||||
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker';
|
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||||
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
|
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
@ -143,15 +143,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
vae: vae ?? undefined,
|
vae: vae ?? undefined,
|
||||||
});
|
});
|
||||||
|
|
||||||
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
|
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||||
|
|
||||||
addGenerationTabLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
|
addLoRAs(state, g, denoise, modelLoader, seamless, clipSkip, posCond, negCond);
|
||||||
|
|
||||||
// We might get the VAE from the main model, custom VAE, or seamless node.
|
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||||
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||||
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
g.addEdge(vaeSource, 'vae', l2i, 'vae');
|
||||||
|
|
||||||
const addedLayers = await addGenerationTabControlLayers(
|
const addedLayers = await addControlLayers(
|
||||||
state,
|
state,
|
||||||
g,
|
g,
|
||||||
modelConfig.base,
|
modelConfig.base,
|
||||||
@ -166,15 +166,15 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
|
|
||||||
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
const isHRFAllowed = !addedLayers.some((l) => isInitialImageLayer(l) || isRegionalGuidanceLayer(l));
|
||||||
if (isHRFAllowed && state.hrf.hrfEnabled) {
|
if (isHRFAllowed && state.hrf.hrfEnabled) {
|
||||||
imageOutput = addGenerationTabHRF(state, g, denoise, noise, l2i, vaeSource);
|
imageOutput = addHRF(state, g, denoise, noise, l2i, vaeSource);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
|
imageOutput = addNSFWChecker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.system.shouldUseWatermarker) {
|
if (state.system.shouldUseWatermarker) {
|
||||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
imageOutput = addWatermarker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
g.setMetadataReceivingNode(imageOutput);
|
g.setMetadataReceivingNode(imageOutput);
|
||||||
|
@ -12,12 +12,12 @@ import {
|
|||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from 'features/nodes/util/graph/constants';
|
} from 'features/nodes/util/graph/constants';
|
||||||
import { addGenerationTabControlLayers } from 'features/nodes/util/graph/generation/addGenerationTabControlLayers';
|
import { addControlLayers } from 'features/nodes/util/graph/generation/addControlLayers';
|
||||||
import { addGenerationTabNSFWChecker } from 'features/nodes/util/graph/generation/addGenerationTabNSFWChecker';
|
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||||
import { addGenerationTabSDXLLoRAs } from 'features/nodes/util/graph/generation/addGenerationTabSDXLLoRAs';
|
import { addSDXLLoRas } from 'features/nodes/util/graph/generation/addSDXLLoRAs';
|
||||||
import { addGenerationTabSDXLRefiner } from 'features/nodes/util/graph/generation/addGenerationTabSDXLRefiner';
|
import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefiner';
|
||||||
import { addGenerationTabSeamless } from 'features/nodes/util/graph/generation/addGenerationTabSeamless';
|
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
|
||||||
import { addGenerationTabWatermarker } from 'features/nodes/util/graph/generation/addGenerationTabWatermarker';
|
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
import type { Invocation, NonNullableGraph } from 'services/api/types';
|
||||||
@ -135,9 +135,9 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
vae: vae ?? undefined,
|
vae: vae ?? undefined,
|
||||||
});
|
});
|
||||||
|
|
||||||
const seamless = addGenerationTabSeamless(state, g, denoise, modelLoader, vaeLoader);
|
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
|
||||||
|
|
||||||
addGenerationTabSDXLLoRAs(state, g, denoise, modelLoader, seamless, posCond, negCond);
|
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond);
|
||||||
|
|
||||||
// We might get the VAE from the main model, custom VAE, or seamless node.
|
// We might get the VAE from the main model, custom VAE, or seamless node.
|
||||||
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
const vaeSource = seamless ?? vaeLoader ?? modelLoader;
|
||||||
@ -145,10 +145,10 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
await addGenerationTabSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
|
await addSDXLRefiner(state, g, denoise, modelLoader, seamless, posCond, negCond, l2i);
|
||||||
}
|
}
|
||||||
|
|
||||||
await addGenerationTabControlLayers(
|
await addControlLayers(
|
||||||
state,
|
state,
|
||||||
g,
|
g,
|
||||||
modelConfig.base,
|
modelConfig.base,
|
||||||
@ -162,11 +162,11 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
);
|
);
|
||||||
|
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
imageOutput = addGenerationTabNSFWChecker(g, imageOutput);
|
imageOutput = addNSFWChecker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.system.shouldUseWatermarker) {
|
if (state.system.shouldUseWatermarker) {
|
||||||
imageOutput = addGenerationTabWatermarker(g, imageOutput);
|
imageOutput = addWatermarker(g, imageOutput);
|
||||||
}
|
}
|
||||||
|
|
||||||
g.setMetadataReceivingNode(imageOutput);
|
g.setMetadataReceivingNode(imageOutput);
|
||||||
|
Loading…
Reference in New Issue
Block a user