fix: SDXL Refiner not working properly with Inpainting

This commit is contained in:
blessedcoolant 2024-04-09 07:19:16 +05:30 committed by psychedelicious
parent 381b41a56e
commit fd1f240853
3 changed files with 5 additions and 68 deletions

View File

@ -336,9 +336,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
if is_inpainting_model(self.unet):
if masked_latents is None:

View File

@ -1,25 +1,18 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type {
CreateGradientMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
LATENTS_TO_IMAGE,
MASK_COMBINE,
MASK_RESIZE_UP,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
@ -32,9 +25,7 @@ export const addSDXLRefinerToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
modelLoaderNodeId?: string
): Promise<void> => {
const {
refinerModel,
@ -46,8 +37,6 @@ export const addSDXLRefinerToGraph = async (
refinerStart,
} = state.sdxl;
const { canvasCoherenceEdgeSize, canvasCoherenceMinDenoise, canvasCoherenceMode } = state.generation;
if (!refinerModel) {
return;
}
@ -213,51 +202,9 @@ export const addSDXLRefinerToGraph = async (
);
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_gradient_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true,
edge_radius: canvasCoherenceEdgeSize,
coherence_mode: canvasCoherenceMode,
minimum_denoise: canvasCoherenceMinDenoise,
};
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
if (isUsingScaledDimensions) {
graph.edges.push({
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
mask: canvasMaskImage,
};
}
}
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
});
}
graph.edges.push({
source: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
node_id: INPAINT_CREATE_MASK,
field: 'denoise_mask',
},
destination: {

View File

@ -426,14 +426,7 @@ export const buildCanvasSDXLInpaintGraph = async (
// Add Refiner if enabled
if (refinerModel) {
await addSDXLRefinerToGraph(
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}