mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: SDXL Refiner not working properly with Inpainting
This commit is contained in:
parent
40a964792d
commit
5a1f4cb1ce
@ -336,9 +336,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
|
@ -1,25 +1,18 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type {
|
||||
CreateGradientMaskInvocation,
|
||||
ImageDTO,
|
||||
NonNullableGraph,
|
||||
SeamlessModeInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types';
|
||||
import { isRefinerMainModelModelConfig } from 'services/api/types';
|
||||
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_CREATE_MASK,
|
||||
LATENTS_TO_IMAGE,
|
||||
MASK_COMBINE,
|
||||
MASK_RESIZE_UP,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_DENOISE_LATENTS,
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SDXL_REFINER_MODEL_LOADER,
|
||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
@ -32,9 +25,7 @@ export const addSDXLRefinerToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId?: string,
|
||||
canvasInitImage?: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
modelLoaderNodeId?: string
|
||||
): Promise<void> => {
|
||||
const {
|
||||
refinerModel,
|
||||
@ -46,8 +37,6 @@ export const addSDXLRefinerToGraph = async (
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
const { canvasCoherenceEdgeSize, canvasCoherenceMinDenoise, canvasCoherenceMode } = state.generation;
|
||||
|
||||
if (!refinerModel) {
|
||||
return;
|
||||
}
|
||||
@ -213,51 +202,9 @@ export const addSDXLRefinerToGraph = async (
|
||||
);
|
||||
|
||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||
type: 'create_gradient_mask',
|
||||
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
minimum_denoise: canvasCoherenceMinDenoise,
|
||||
};
|
||||
|
||||
if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] as CreateGradientMaskInvocation),
|
||||
mask: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
|
@ -426,14 +426,7 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
await addSDXLRefinerToGraph(
|
||||
state,
|
||||
graph,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
modelLoaderNodeId,
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user