Merge branch 'main' into pr/4352

This commit is contained in:
blessedcoolant 2023-08-29 12:40:01 +12:00
commit 6fdeeb8ce8
34 changed files with 794 additions and 112 deletions

View File

@ -34,6 +34,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_seamless
from ...backend.model_management.models import BaseModelType from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -456,7 +457,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader() unet_info.context.model, _lora_loader()
), unet_info as unet: ), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
latents = latents.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None: if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -549,7 +550,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context, context=context,
) )
with vae_info as vae: with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)

View File

@ -8,8 +8,8 @@ from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
FieldDescriptions, FieldDescriptions,
InputField,
Input, Input,
InputField,
InvocationContext, InvocationContext,
OutputField, OutputField,
UIType, UIType,
@ -33,6 +33,7 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class ClipField(BaseModel): class ClipField(BaseModel):
@ -45,6 +46,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel): class VaeField(BaseModel):
# TODO: better naming? # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel") vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class ModelLoaderOutput(BaseInvocationOutput): class ModelLoaderOutput(BaseInvocationOutput):
@ -388,3 +390,50 @@ class VaeLoaderInvocation(BaseInvocation):
) )
) )
) )
class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
type: Literal["seamless_output"] = "seamless_output"
# Outputs
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("Seamless")
@tags("seamless", "model")
class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE."""
type: Literal["seamless"] = "seamless"
# Inputs
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
vae: Optional[VaeField] = InputField(
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
seamless_axes_list = []
if self.seamless_x:
seamless_axes_list.append("x")
if self.seamless_y:
seamless_axes_list.append("y")
if unet is not None:
unet.seamless_axes = seamless_axes_list
if vae is not None:
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)

View File

@ -20,7 +20,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
def configure_model_padding(model, seamless, seamless_axes): def configure_model_padding(model, seamless, seamless_axes):
""" """
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options. Modifies the 2D convolution layers to use a circular padding mode based on
the `seamless` and `seamless_axes` options.
""" """
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556 # TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
for m in model.modules(): for m in model.modules():

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import List, Union
import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
print(f"applied - {m_name}")
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding

View File

@ -761,3 +761,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
diffusers.ControlNetModel = ControlNetModel diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel
# patch LoRACompatibleConv to use original Conv2D forward function
# this needed to make work seamless patch
# NOTE: with this patch, torch.compile crashes on 2.0 torch(already fixed in nightly)
# https://github.com/huggingface/diffusers/pull/4315
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/lora.py#L96C18-L96C18
def new_LoRACompatibleConv_forward(self, x):
if self.lora_layer is None:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x)
else:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x) + self.lora_layer(x)
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward

View File

@ -14,6 +14,7 @@ import i18n from 'i18n';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { ReactNode, memo, useCallback, useEffect } from 'react'; import { ReactNode, memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary'; import { ErrorBoundary } from 'react-error-boundary';
import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
@ -23,13 +24,22 @@ const DEFAULT_CONFIG = {};
interface Props { interface Props {
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode; headerComponent?: ReactNode;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
} }
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { const App = ({
config = DEFAULT_CONFIG,
headerComponent,
selectedImage,
}: Props) => {
const language = useAppSelector(languageSelector); const language = useAppSelector(languageSelector);
const logger = useLogger('system'); const logger = useLogger('system');
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { handlePreselectedImage } = usePreselectedImage();
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
localStorage.clear(); localStorage.clear();
location.reload(); location.reload();
@ -51,6 +61,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
dispatch(appStarted()); dispatch(appStarted());
}, [dispatch]); }, [dispatch]);
useEffect(() => {
handlePreselectedImage(selectedImage);
}, [handlePreselectedImage, selectedImage]);
return ( return (
<ErrorBoundary <ErrorBoundary
onReset={handleReset} onReset={handleReset}

View File

@ -26,6 +26,10 @@ interface Props extends PropsWithChildren {
headerComponent?: ReactNode; headerComponent?: ReactNode;
middleware?: Middleware[]; middleware?: Middleware[];
projectId?: string; projectId?: string;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
} }
const InvokeAIUI = ({ const InvokeAIUI = ({
@ -35,6 +39,7 @@ const InvokeAIUI = ({
headerComponent, headerComponent,
middleware, middleware,
projectId, projectId,
selectedImage,
}: Props) => { }: Props) => {
useEffect(() => { useEffect(() => {
// configure API client token // configure API client token
@ -81,7 +86,11 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<AppDndContext> <AppDndContext>
<App config={config} headerComponent={headerComponent} /> <App
config={config}
headerComponent={headerComponent}
selectedImage={selectedImage}
/>
</AppDndContext> </AppDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>

View File

@ -8,7 +8,7 @@ import {
ImageDraggableData, ImageDraggableData,
TypesafeDraggableData, TypesafeDraggableData,
} from 'features/dnd/types'; } from 'features/dnd/types';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts'; import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react'; import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md'; import { MdStar, MdStarBorder } from 'react-icons/md';

View File

@ -63,7 +63,7 @@ export const addDynamicPromptsToGraph = (
{ {
source: { source: {
node_id: DYNAMIC_PROMPT, node_id: DYNAMIC_PROMPT,
field: 'prompt_collection', field: 'collection',
}, },
destination: { destination: {
node_id: ITERATE, node_id: ITERATE,

View File

@ -11,9 +11,11 @@ import {
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
export const addSDXLLoRAsToGraph = ( export const addSDXLLoRAsToGraph = (
@ -36,20 +38,25 @@ export const addSDXLLoRAsToGraph = (
| MetadataAccumulatorInvocation | MetadataAccumulatorInvocation
| undefined; | undefined;
// Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId;
let clipLoaderId = modelLoaderNodeId;
if ([SEAMLESS, REFINER_SEAMLESS].includes(modelLoaderNodeId)) {
clipLoaderId = SDXL_MODEL_LOADER;
}
if (loraCount > 0) { if (loraCount > 0) {
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs // Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
graph.edges = graph.edges.filter( graph.edges = graph.edges.filter(
(e) => (e) =>
!( !(
e.source.node_id === modelLoaderNodeId && e.source.node_id === unetLoaderId && ['unet'].includes(e.source.field)
['unet'].includes(e.source.field)
) && ) &&
!( !(
e.source.node_id === modelLoaderNodeId && e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
['clip'].includes(e.source.field)
) && ) &&
!( !(
e.source.node_id === modelLoaderNodeId && e.source.node_id === clipLoaderId &&
['clip2'].includes(e.source.field) ['clip2'].includes(e.source.field)
) )
); );
@ -88,7 +95,7 @@ export const addSDXLLoRAsToGraph = (
// first lora = start the lora chain, attach directly to model loader // first lora = start the lora chain, attach directly to model loader
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: modelLoaderNodeId, node_id: unetLoaderId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -99,7 +106,7 @@ export const addSDXLLoRAsToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: modelLoaderNodeId, node_id: clipLoaderId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -110,7 +117,7 @@ export const addSDXLLoRAsToGraph = (
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: modelLoaderNodeId, node_id: clipLoaderId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {

View File

@ -1,11 +1,15 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types'; import {
MetadataAccumulatorInvocation,
SeamlessModeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types'; import { NonNullableGraph } from '../../types/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_BLUR, MASK_BLUR,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
@ -21,7 +25,8 @@ import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
export const addSDXLRefinerToGraph = ( export const addSDXLRefinerToGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string baseNodeId: string,
modelLoaderNodeId?: string
): void => { ): void => {
const { const {
refinerModel, refinerModel,
@ -33,6 +38,8 @@ export const addSDXLRefinerToGraph = (
refinerStart, refinerStart,
} = state.sdxl; } = state.sdxl;
const { seamlessXAxis, seamlessYAxis } = state.generation;
if (!refinerModel) { if (!refinerModel) {
return; return;
} }
@ -53,6 +60,10 @@ export const addSDXLRefinerToGraph = (
metadataAccumulator.refiner_steps = refinerSteps; metadataAccumulator.refiner_steps = refinerSteps;
} }
const modelLoaderId = modelLoaderNodeId
? modelLoaderNodeId
: SDXL_MODEL_LOADER;
// Construct Style Prompt // Construct Style Prompt
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } = const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
craftSDXLStylePrompt(state, true); craftSDXLStylePrompt(state, true);
@ -65,10 +76,7 @@ export const addSDXLRefinerToGraph = (
graph.edges = graph.edges.filter( graph.edges = graph.edges.filter(
(e) => (e) =>
!( !(e.source.node_id === modelLoaderId && ['vae'].includes(e.source.field))
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
); );
graph.nodes[SDXL_REFINER_MODEL_LOADER] = { graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
@ -98,8 +106,39 @@ export const addSDXLRefinerToGraph = (
denoising_end: 1, denoising_end: 1,
}; };
// Add Seamless To Refiner
if (seamlessXAxis || seamlessYAxis) {
graph.nodes[REFINER_SEAMLESS] = {
id: REFINER_SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
graph.edges.push( graph.edges.push(
{ {
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: REFINER_SEAMLESS,
field: 'unet',
},
},
{
source: {
node_id: REFINER_SEAMLESS,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'unet',
},
}
);
} else {
graph.edges.push({
source: { source: {
node_id: SDXL_REFINER_MODEL_LOADER, node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet', field: 'unet',
@ -108,7 +147,10 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_DENOISE_LATENTS, node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'unet', field: 'unet',
}, },
}, });
}
graph.edges.push(
{ {
source: { source: {
node_id: SDXL_REFINER_MODEL_LOADER, node_id: SDXL_REFINER_MODEL_LOADER,

View File

@ -0,0 +1,109 @@
import { RootState } from 'app/store/store';
import { SeamlessModeInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
CANVAS_INPAINT_GRAPH,
CANVAS_OUTPAINT_GRAPH,
DENOISE_LATENTS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants';
export const addSeamlessToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string
): void => {
// Remove Existing UNet Connections
const { seamlessXAxis, seamlessYAxis } = state.generation;
graph.nodes[SEAMLESS] = {
id: SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
let denoisingNodeId = DENOISE_LATENTS;
if (
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
denoisingNodeId = SDXL_DENOISE_LATENTS;
}
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['vae'].includes(e.source.field)
)
);
graph.edges.push(
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: SEAMLESS,
field: 'unet',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'vae',
},
destination: {
node_id: SEAMLESS,
field: 'vae',
},
},
{
source: {
node_id: SEAMLESS,
field: 'unet',
},
destination: {
node_id: denoisingNodeId,
field: 'unet',
},
}
);
if (
graph.id == CANVAS_INPAINT_GRAPH ||
graph.id === CANVAS_OUTPAINT_GRAPH ||
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: SEAMLESS,
field: 'unet',
},
destination: {
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
field: 'unet',
},
});
}
};

View File

@ -7,6 +7,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -22,6 +23,7 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants'; } from './constants';
/** /**
@ -44,6 +46,8 @@ export const buildCanvasImageToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
@ -64,6 +68,8 @@ export const buildCanvasImageToImageGraph = (
throw new Error('No model found in state'); throw new Error('No model found in state');
} }
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
@ -81,9 +87,9 @@ export const buildCanvasImageToImageGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: CANVAS_IMAGE_TO_IMAGE_GRAPH, id: CANVAS_IMAGE_TO_IMAGE_GRAPH,
nodes: { nodes: {
[MAIN_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'main_model_loader', type: 'main_model_loader',
id: MAIN_MODEL_LOADER, id: modelLoaderNodeId,
is_intermediate: true, is_intermediate: true,
model, model,
}, },
@ -142,7 +148,7 @@ export const buildCanvasImageToImageGraph = (
// Connect Model Loader to CLIP Skip and UNet // Connect Model Loader to CLIP Skip and UNet
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -152,7 +158,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -340,11 +346,17 @@ export const buildCanvasImageToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// add LoRA support // add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS); addLoRAsToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add dynamic prompts - also sets up core iteration and seed // add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph); addDynamicPromptsToGraph(state, graph);

View File

@ -13,6 +13,7 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -38,6 +39,7 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
SEAMLESS,
} from './constants'; } from './constants';
/** /**
@ -68,6 +70,8 @@ export const buildCanvasInpaintGraph = (
canvasCoherenceSteps, canvasCoherenceSteps,
canvasCoherenceStrength, canvasCoherenceStrength,
clipSkip, clipSkip,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
if (!model) { if (!model) {
@ -85,6 +89,8 @@ export const buildCanvasInpaintGraph = (
shouldAutoSave, shouldAutoSave,
} = state.canvas; } = state.canvas;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: shouldUseCpuNoise; : shouldUseCpuNoise;
@ -92,9 +98,9 @@ export const buildCanvasInpaintGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: CANVAS_INPAINT_GRAPH, id: CANVAS_INPAINT_GRAPH,
nodes: { nodes: {
[MAIN_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'main_model_loader', type: 'main_model_loader',
id: MAIN_MODEL_LOADER, id: modelLoaderNodeId,
is_intermediate: true, is_intermediate: true,
model, model,
}, },
@ -204,7 +210,7 @@ export const buildCanvasInpaintGraph = (
// Connect Model Loader to CLIP Skip and UNet // Connect Model Loader to CLIP Skip and UNet
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -214,7 +220,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -349,7 +355,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -595,11 +601,17 @@ export const buildCanvasInpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
} }
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add VAE // Add VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, MAIN_MODEL_LOADER); addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

View File

@ -14,6 +14,7 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -43,6 +44,7 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
SEAMLESS,
} from './constants'; } from './constants';
/** /**
@ -75,6 +77,8 @@ export const buildCanvasOutpaintGraph = (
tileSize, tileSize,
infillMethod, infillMethod,
clipSkip, clipSkip,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
if (!model) { if (!model) {
@ -92,6 +96,8 @@ export const buildCanvasOutpaintGraph = (
shouldAutoSave, shouldAutoSave,
} = state.canvas; } = state.canvas;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: shouldUseCpuNoise; : shouldUseCpuNoise;
@ -99,9 +105,9 @@ export const buildCanvasOutpaintGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: CANVAS_OUTPAINT_GRAPH, id: CANVAS_OUTPAINT_GRAPH,
nodes: { nodes: {
[MAIN_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'main_model_loader', type: 'main_model_loader',
id: MAIN_MODEL_LOADER, id: modelLoaderNodeId,
is_intermediate: true, is_intermediate: true,
model, model,
}, },
@ -222,7 +228,7 @@ export const buildCanvasOutpaintGraph = (
// Connect Model Loader To UNet & Clip Skip // Connect Model Loader To UNet & Clip Skip
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -232,7 +238,7 @@ export const buildCanvasOutpaintGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -389,7 +395,7 @@ export const buildCanvasOutpaintGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -732,11 +738,17 @@ export const buildCanvasOutpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
} }
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add VAE // Add VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, MAIN_MODEL_LOADER); addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);

View File

@ -8,6 +8,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -19,9 +20,11 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -44,6 +47,8 @@ export const buildCanvasSDXLImageToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const { const {
@ -71,6 +76,9 @@ export const buildCanvasSDXLImageToImageGraph = (
throw new Error('No model found in state'); throw new Error('No model found in state');
} }
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
@ -92,9 +100,9 @@ export const buildCanvasSDXLImageToImageGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, id: SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
nodes: { nodes: {
[SDXL_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER, id: modelLoaderNodeId,
model, model,
}, },
[POSITIVE_CONDITIONING]: { [POSITIVE_CONDITIONING]: {
@ -144,7 +152,7 @@ export const buildCanvasSDXLImageToImageGraph = (
// Connect Model Loader To UNet & CLIP // Connect Model Loader To UNet & CLIP
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -154,7 +162,7 @@ export const buildCanvasSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -164,7 +172,7 @@ export const buildCanvasSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -174,7 +182,7 @@ export const buildCanvasSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -184,7 +192,7 @@ export const buildCanvasSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -351,16 +359,23 @@ export const buildCanvasSDXLImageToImageGraph = (
}, },
}); });
// add LoRA support // Add Seamless To Graph
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER); if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add dynamic prompts - also sets up core iteration and seed // add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph); addDynamicPromptsToGraph(state, graph);

View File

@ -14,6 +14,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -35,9 +36,11 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -67,6 +70,8 @@ export const buildCanvasSDXLInpaintGraph = (
maskBlurMethod, maskBlurMethod,
canvasCoherenceSteps, canvasCoherenceSteps,
canvasCoherenceStrength, canvasCoherenceStrength,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const { const {
@ -91,6 +96,8 @@ export const buildCanvasSDXLInpaintGraph = (
shouldAutoSave, shouldAutoSave,
} = state.canvas; } = state.canvas;
let modelLoaderNodeId = SDXL_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: shouldUseCpuNoise; : shouldUseCpuNoise;
@ -102,9 +109,9 @@ export const buildCanvasSDXLInpaintGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: SDXL_CANVAS_INPAINT_GRAPH, id: SDXL_CANVAS_INPAINT_GRAPH,
nodes: { nodes: {
[SDXL_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER, id: modelLoaderNodeId,
model, model,
}, },
[POSITIVE_CONDITIONING]: { [POSITIVE_CONDITIONING]: {
@ -209,7 +216,7 @@ export const buildCanvasSDXLInpaintGraph = (
// Connect Model Loader to UNet and CLIP // Connect Model Loader to UNet and CLIP
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -219,7 +226,7 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -229,7 +236,7 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -239,7 +246,7 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -249,7 +256,7 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -363,7 +370,7 @@ export const buildCanvasSDXLInpaintGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -609,16 +616,28 @@ export const buildCanvasSDXLInpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
} }
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, CANVAS_COHERENCE_DENOISE_LATENTS); addSDXLRefinerToGraph(
state,
graph,
CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId
);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER); addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -15,6 +15,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -40,9 +41,11 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -74,6 +77,8 @@ export const buildCanvasSDXLOutpaintGraph = (
canvasCoherenceStrength, canvasCoherenceStrength,
tileSize, tileSize,
infillMethod, infillMethod,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const { const {
@ -98,6 +103,8 @@ export const buildCanvasSDXLOutpaintGraph = (
shouldAutoSave, shouldAutoSave,
} = state.canvas; } = state.canvas;
let modelLoaderNodeId = SDXL_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: shouldUseCpuNoise; : shouldUseCpuNoise;
@ -747,16 +754,28 @@ export const buildCanvasSDXLOutpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed; (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
} }
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, CANVAS_COHERENCE_DENOISE_LATENTS); addSDXLRefinerToGraph(
state,
graph,
CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId
);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER); addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -11,6 +11,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -21,9 +22,11 @@ import {
NOISE, NOISE,
ONNX_MODEL_LOADER, ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -45,6 +48,8 @@ export const buildCanvasSDXLTextToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
@ -74,7 +79,7 @@ export const buildCanvasSDXLTextToImageGraph = (
const isUsingOnnxModel = model.model_type === 'onnx'; const isUsingOnnxModel = model.model_type === 'onnx';
const modelLoaderNodeId = isUsingOnnxModel let modelLoaderNodeId = isUsingOnnxModel
? ONNX_MODEL_LOADER ? ONNX_MODEL_LOADER
: SDXL_MODEL_LOADER; : SDXL_MODEL_LOADER;
@ -334,9 +339,16 @@ export const buildCanvasSDXLTextToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// add LoRA support // add LoRA support

View File

@ -10,6 +10,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -24,6 +25,7 @@ import {
NOISE, NOISE,
ONNX_MODEL_LOADER, ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants'; } from './constants';
/** /**
@ -44,6 +46,8 @@ export const buildCanvasTextToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
@ -70,7 +74,7 @@ export const buildCanvasTextToImageGraph = (
const isUsingOnnxModel = model.model_type === 'onnx'; const isUsingOnnxModel = model.model_type === 'onnx';
const modelLoaderNodeId = isUsingOnnxModel let modelLoaderNodeId = isUsingOnnxModel
? ONNX_MODEL_LOADER ? ONNX_MODEL_LOADER
: MAIN_MODEL_LOADER; : MAIN_MODEL_LOADER;
@ -321,6 +325,12 @@ export const buildCanvasTextToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId); addVAEToGraph(state, graph, modelLoaderNodeId);

View File

@ -10,6 +10,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -24,6 +25,7 @@ import {
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RESIZE, RESIZE,
SEAMLESS,
} from './constants'; } from './constants';
/** /**
@ -49,6 +51,8 @@ export const buildLinearImageToImageGraph = (
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision, vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
// TODO: add batch functionality // TODO: add batch functionality
@ -80,6 +84,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No model found in state'); throw new Error('No model found in state');
} }
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
@ -88,9 +94,9 @@ export const buildLinearImageToImageGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH, id: IMAGE_TO_IMAGE_GRAPH,
nodes: { nodes: {
[MAIN_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'main_model_loader', type: 'main_model_loader',
id: MAIN_MODEL_LOADER, id: modelLoaderNodeId,
model, model,
}, },
[CLIP_SKIP]: { [CLIP_SKIP]: {
@ -141,7 +147,7 @@ export const buildLinearImageToImageGraph = (
// Connect Model Loader to UNet and CLIP Skip // Connect Model Loader to UNet and CLIP Skip
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -151,7 +157,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MAIN_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -338,11 +344,17 @@ export const buildLinearImageToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS); addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add dynamic prompts - also sets up core iteration and seed // add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph); addDynamicPromptsToGraph(state, graph);

View File

@ -11,6 +11,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -20,10 +21,12 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
RESIZE, RESIZE,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -49,6 +52,8 @@ export const buildLinearSDXLImageToImageGraph = (
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision, vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const { const {
@ -79,6 +84,9 @@ export const buildLinearSDXLImageToImageGraph = (
throw new Error('No model found in state'); throw new Error('No model found in state');
} }
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise ? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise; : initialGenerationState.shouldUseCpuNoise;
@ -91,9 +99,9 @@ export const buildLinearSDXLImageToImageGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: SDXL_IMAGE_TO_IMAGE_GRAPH, id: SDXL_IMAGE_TO_IMAGE_GRAPH,
nodes: { nodes: {
[SDXL_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER, id: modelLoaderNodeId,
model, model,
}, },
[POSITIVE_CONDITIONING]: { [POSITIVE_CONDITIONING]: {
@ -143,7 +151,7 @@ export const buildLinearSDXLImageToImageGraph = (
// Connect Model Loader to UNet, CLIP & VAE // Connect Model Loader to UNet, CLIP & VAE
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -153,7 +161,7 @@ export const buildLinearSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -163,7 +171,7 @@ export const buildLinearSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -173,7 +181,7 @@ export const buildLinearSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -183,7 +191,7 @@ export const buildLinearSDXLImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -351,15 +359,23 @@ export const buildLinearSDXLImageToImageGraph = (
}, },
}); });
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER); // Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// Add LoRA Support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -7,6 +7,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -15,9 +16,11 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_TEXT_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -38,6 +41,8 @@ export const buildLinearSDXLTextToImageGraph = (
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision, vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const { const {
@ -61,6 +66,9 @@ export const buildLinearSDXLTextToImageGraph = (
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } = const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt); craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt);
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -74,9 +82,9 @@ export const buildLinearSDXLTextToImageGraph = (
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: SDXL_TEXT_TO_IMAGE_GRAPH, id: SDXL_TEXT_TO_IMAGE_GRAPH,
nodes: { nodes: {
[SDXL_MODEL_LOADER]: { [modelLoaderNodeId]: {
type: 'sdxl_model_loader', type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER, id: modelLoaderNodeId,
model, model,
}, },
[POSITIVE_CONDITIONING]: { [POSITIVE_CONDITIONING]: {
@ -117,7 +125,7 @@ export const buildLinearSDXLTextToImageGraph = (
// Connect Model Loader to UNet, VAE & CLIP // Connect Model Loader to UNet, VAE & CLIP
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -127,7 +135,7 @@ export const buildLinearSDXLTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -137,7 +145,7 @@ export const buildLinearSDXLTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -147,7 +155,7 @@ export const buildLinearSDXLTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -157,7 +165,7 @@ export const buildLinearSDXLTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: SDXL_MODEL_LOADER, node_id: modelLoaderNodeId,
field: 'clip2', field: 'clip2',
}, },
destination: { destination: {
@ -244,16 +252,23 @@ export const buildLinearSDXLTextToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
} }
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER); addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support // add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER); addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph` // add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -10,6 +10,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph'; import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import { import {
@ -22,6 +23,7 @@ import {
NOISE, NOISE,
ONNX_MODEL_LOADER, ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
SEAMLESS,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
} from './constants'; } from './constants';
@ -42,6 +44,8 @@ export const buildLinearTextToImageGraph = (
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision, vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation; } = state.generation;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
@ -55,7 +59,7 @@ export const buildLinearTextToImageGraph = (
const isUsingOnnxModel = model.model_type === 'onnx'; const isUsingOnnxModel = model.model_type === 'onnx';
const modelLoaderNodeId = isUsingOnnxModel let modelLoaderNodeId = isUsingOnnxModel
? ONNX_MODEL_LOADER ? ONNX_MODEL_LOADER
: MAIN_MODEL_LOADER; : MAIN_MODEL_LOADER;
@ -258,6 +262,12 @@ export const buildLinearTextToImageGraph = (
}, },
}); });
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// optionally add custom VAE // optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId); addVAEToGraph(state, graph, modelLoaderNodeId);

View File

@ -56,6 +56,8 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
export const SDXL_REFINER_NEGATIVE_CONDITIONING = export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning'; 'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents'; export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SEAMLESS = 'seamless';
export const REFINER_SEAMLESS = 'refiner_seamless';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -0,0 +1,81 @@
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { t } from 'i18next';
import { useCallback, useState } from 'react';
import { useAppToaster } from '../../../app/components/Toaster';
import { useAppDispatch } from '../../../app/store/storeHooks';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
} from '../../../services/api/endpoints/images';
import { setInitialCanvasImage } from '../../canvas/store/canvasSlice';
import { setActiveTab } from '../../ui/store/uiSlice';
import { initialImageSelected } from '../store/actions';
import { useRecallParameters } from './useRecallParameters';
type SelectedImage = {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
export const usePreselectedImage = () => {
const dispatch = useAppDispatch();
const [imageNameForDto, setImageNameForDto] = useState<string | undefined>();
const [imageNameForMetadata, setImageNameForMetadata] = useState<
string | undefined
>();
const { recallAllParameters } = useRecallParameters();
const toaster = useAppToaster();
const { currentData: selectedImageDto } = useGetImageDTOQuery(
imageNameForDto ?? skipToken
);
const { currentData: selectedImageMetadata } = useGetImageMetadataQuery(
imageNameForMetadata ?? skipToken
);
const handlePreselectedImage = useCallback(
(selectedImage?: SelectedImage) => {
if (!selectedImage) {
return;
}
if (selectedImage.action === 'sendToCanvas') {
setImageNameForDto(selectedImage?.imageName);
if (selectedImageDto) {
dispatch(setInitialCanvasImage(selectedImageDto));
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'info',
duration: 2500,
isClosable: true,
});
}
}
if (selectedImage.action === 'sendToImg2Img') {
setImageNameForDto(selectedImage?.imageName);
if (selectedImageDto) {
dispatch(initialImageSelected(selectedImageDto));
}
}
if (selectedImage.action === 'useAllParameters') {
setImageNameForMetadata(selectedImage?.imageName);
if (selectedImageMetadata) {
recallAllParameters(selectedImageMetadata.metadata);
}
}
},
[
dispatch,
selectedImageDto,
selectedImageMetadata,
recallAllParameters,
toaster,
]
);
return { handlePreselectedImage };
};

View File

@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import { memo } from 'react'; import { memo } from 'react';
import ParamSDXLPromptArea from './ParamSDXLPromptArea'; import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
@ -17,6 +18,7 @@ const SDXLImageToImageTabParameters = () => {
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSeamlessCollapse />
</> </>
); );
}; };

View File

@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import { memo } from 'react'; import { memo } from 'react';
import ParamSDXLPromptArea from './ParamSDXLPromptArea'; import ParamSDXLPromptArea from './ParamSDXLPromptArea';
@ -17,6 +18,7 @@ const SDXLTextToImageTabParameters = () => {
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSeamlessCollapse />
</> </>
); );
}; };

View File

@ -5,6 +5,7 @@ import ParamMaskAdjustmentCollapse from 'features/parameters/components/Paramete
import ParamCanvasCoherencePassCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamCanvasCoherencePassCollapse'; import ParamCanvasCoherencePassCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamCanvasCoherencePassCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSDXLPromptArea from './ParamSDXLPromptArea'; import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLUnifiedCanvasTabCoreParameters from './SDXLUnifiedCanvasTabCoreParameters'; import SDXLUnifiedCanvasTabCoreParameters from './SDXLUnifiedCanvasTabCoreParameters';
@ -22,6 +23,7 @@ export default function SDXLUnifiedCanvasTabParameters() {
<ParamMaskAdjustmentCollapse /> <ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse /> <ParamInfillAndScalingCollapse />
<ParamCanvasCoherencePassCollapse /> <ParamCanvasCoherencePassCollapse />
<ParamSeamlessCollapse />
</> </>
); );
} }

View File

@ -9,7 +9,6 @@ export const initialConfigState: AppConfig = {
disabledFeatures: ['lightbox', 'faceRestore', 'batches'], disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: [ disabledSDFeatures: [
'variation', 'variation',
'seamless',
'symmetry', 'symmetry',
'hires', 'hires',
'perlinNoise', 'perlinNoise',

View File

@ -6,6 +6,7 @@ import ParamMaskAdjustmentCollapse from 'features/parameters/components/Paramete
import ParamCanvasCoherencePassCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamCanvasCoherencePassCollapse'; import ParamCanvasCoherencePassCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamCanvasCoherencePassCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea'; import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
import { memo } from 'react'; import { memo } from 'react';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@ -22,6 +23,7 @@ const UnifiedCanvasParameters = () => {
<ParamMaskAdjustmentCollapse /> <ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse /> <ParamInfillAndScalingCollapse />
<ParamCanvasCoherencePassCollapse /> <ParamCanvasCoherencePassCollapse />
<ParamSeamlessCollapse />
<ParamAdvancedCollapse /> <ParamAdvancedCollapse />
</> </>
); );

File diff suppressed because one or more lines are too long

View File

@ -130,6 +130,7 @@ export type ESRGANInvocation = s['ESRGANInvocation'];
export type DivideInvocation = s['DivideInvocation']; export type DivideInvocation = s['DivideInvocation'];
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation']; export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation']; export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
// ControlNet Nodes // ControlNet Nodes
export type ControlNetInvocation = s['ControlNetInvocation']; export type ControlNetInvocation = s['ControlNetInvocation'];