fix: SDXL Refiner not working properly with Inpainting

This commit is contained in:
blessedcoolant 2024-04-09 07:19:16 +05:30
parent 40a964792d
commit 5a1f4cb1ce
3 changed files with 5 additions and 68 deletions

View File

@ -336,9 +336,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
).to(device=orig_latents.device, dtype=orig_latents.dtype) ).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t) latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
if is_inpainting_model(self.unet): if is_inpainting_model(self.unet):
if masked_latents is None: if masked_latents is None:

View File

@ -1,25 +1,18 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
CreateGradientMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
} from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types'; import { isRefinerMainModelModelConfig } from 'services/api/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_COMBINE,
MASK_RESIZE_UP,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_MODEL_LOADER, SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING, SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_POSITIVE_CONDITIONING,
@ -32,9 +25,7 @@ export const addSDXLRefinerToGraph = async (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string, baseNodeId: string,
modelLoaderNodeId?: string, modelLoaderNodeId?: string
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): Promise<void> => { ): Promise<void> => {
const { const {
refinerModel, refinerModel,
@ -46,8 +37,6 @@ export const addSDXLRefinerToGraph = async (
refinerStart, refinerStart,
} = state.sdxl; } = state.sdxl;
const { canvasCoherenceEdgeSize, canvasCoherenceMinDenoise, canvasCoherenceMode } = state.generation;
if (!refinerModel) { if (!refinerModel) {
return; return;
} }
@ -213,51 +202,9 @@ export const addSDXLRefinerToGraph = async (
); );
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) { if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_gradient_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true,
edge_radius: canvasCoherenceEdgeSize,
coherence_mode: canvasCoherenceMode,
minimum_denoise: canvasCoherenceMinDenoise,
};
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
if (isUsingScaledDimensions) {
graph.edges.push({ graph.edges.push({
source: { source: {
node_id: MASK_RESIZE_UP, node_id: INPAINT_CREATE_MASK,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
mask: canvasMaskImage,
};
}
}
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
graph.edges.push({
source: {
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
});
}
graph.edges.push({
source: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'denoise_mask', field: 'denoise_mask',
}, },
destination: { destination: {

View File

@ -426,14 +426,7 @@ export const buildCanvasSDXLInpaintGraph = async (
// Add Refiner if enabled // Add Refiner if enabled
if (refinerModel) { if (refinerModel) {
await addSDXLRefinerToGraph( await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS; modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
} }