rename: Inpaint Mask to Denoise Mask

This commit is contained in:
blessedcoolant 2023-08-27 05:50:13 +12:00
parent 226721ce51
commit c923d094c6
15 changed files with 137 additions and 137 deletions

View File

@ -21,10 +21,10 @@ from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (
DenoiseMaskField,
DenoiseMaskOutput,
ImageField, ImageField,
ImageOutput, ImageOutput,
InpaintMaskField,
InpaintMaskOutput,
LatentsField, LatentsField,
LatentsOutput, LatentsOutput,
build_latents_output, build_latents_output,
@ -57,16 +57,16 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@title("Create Inpaint Mask") @title("Create Denoise Mask")
@tags("mask", "inpaint") @tags("mask", "denoise")
class CreateInpaintMaskInvocation(BaseInvocation): class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for inpaint model run.""" """Creates mask for denoising model run."""
# Metadata # Metadata
type: Literal["create_inpaint_mask"] = "create_inpaint_mask" type: Literal["create_denoise_mask"] = "create_denoise_mask"
# Inputs # Inputs
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted") image: Optional[ImageField] = InputField(default=None, description="Image which will be masked")
mask: ImageField = InputField(description="The mask to use when pasting") mask: ImageField = InputField(description="The mask to use when pasting")
vae: VaeField = InputField( vae: VaeField = InputField(
description=FieldDescriptions.vae, description=FieldDescriptions.vae,
@ -86,7 +86,7 @@ class CreateInpaintMaskInvocation(BaseInvocation):
return mask_tensor return mask_tensor
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> InpaintMaskOutput: def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None: if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB")) image = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -118,8 +118,8 @@ class CreateInpaintMaskInvocation(BaseInvocation):
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask) context.services.latents.save(mask_name, mask)
return InpaintMaskOutput( return DenoiseMaskOutput(
inpaint_mask=InpaintMaskField( denoise_mask=DenoiseMaskField(
mask_name=mask_name, mask_name=mask_name,
masked_latents_name=masked_latents_name, masked_latents_name=masked_latents_name,
), ),
@ -189,7 +189,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5 default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
) )
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[InpaintMaskField] = InputField( denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, default=None,
description=FieldDescriptions.mask, description=FieldDescriptions.mask,
) )
@ -403,13 +403,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context, latents): def prep_inpaint_mask(self, context, latents):
if self.mask is None: if self.denoise_mask is None:
return None, None return None, None
mask = context.services.latents.get(self.mask.mask_name) mask = context.services.latents.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
if self.mask.masked_latents_name is not None: if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.mask.masked_latents_name) masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
else: else:
masked_latents = None masked_latents = None

View File

@ -296,21 +296,21 @@ class ImageCollectionInvocation(BaseInvocation):
# endregion # endregion
# region InpaintMask # region DenoiseMask
class InpaintMaskField(BaseModel): class DenoiseMaskField(BaseModel):
"""An inpaint mask field""" """An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image") mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents") masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
class InpaintMaskOutput(BaseInvocationOutput): class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image""" """Base class for nodes that output a single image"""
type: Literal["inpaint_mask_output"] = "inpaint_mask_output" type: Literal["denoise_mask_output"] = "denoise_mask_output"
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run") denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
# endregion # endregion

View File

@ -10,10 +10,10 @@ import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField'; import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField'; import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField'; import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField'; import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField'; import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField'; import ImageInputField from './inputs/ImageInputField';
import InpaintMaskInputField from './inputs/InpaintMaskInputField';
import LatentsInputField from './inputs/LatentsInputField'; import LatentsInputField from './inputs/LatentsInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField'; import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField'; import MainModelInputField from './inputs/MainModelInputField';
@ -107,11 +107,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
} }
if ( if (
field?.type === 'InpaintMaskField' && field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'InpaintMaskField' fieldTemplate?.type === 'DenoiseMaskField'
) { ) {
return ( return (
<InpaintMaskInputField <DenoiseMaskInputField
nodeId={nodeId} nodeId={nodeId}
field={field} field={field}
fieldTemplate={fieldTemplate} fieldTemplate={fieldTemplate}

View File

@ -0,0 +1,17 @@
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(DenoiseMaskInputFieldComponent);

View File

@ -1,17 +0,0 @@
import {
FieldComponentProps,
InpaintMaskInputFieldTemplate,
InpaintMaskInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
const InpaintMaskInputFieldComponent = (
_props: FieldComponentProps<
InpaintMaskInputFieldValue,
InpaintMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(InpaintMaskInputFieldComponent);

View File

@ -59,9 +59,9 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'Images may be passed between nodes.', description: 'Images may be passed between nodes.',
color: 'purple.500', color: 'purple.500',
}, },
InpaintMaskField: { DenoiseMaskField: {
title: 'Inpaint Mask', title: 'Denoise Mask',
description: 'Inpaint Mask may be passed between nodes', description: 'Denoise Mask may be passed between nodes',
color: 'purple.500', color: 'purple.500',
}, },
LatentsField: { LatentsField: {

View File

@ -64,7 +64,7 @@ export const zFieldType = z.enum([
'string', 'string',
'array', 'array',
'ImageField', 'ImageField',
'InpaintMaskField', 'DenoiseMaskField',
'LatentsField', 'LatentsField',
'ConditioningField', 'ConditioningField',
'ControlField', 'ControlField',
@ -121,7 +121,7 @@ export type InputFieldTemplate =
| StringInputFieldTemplate | StringInputFieldTemplate
| BooleanInputFieldTemplate | BooleanInputFieldTemplate
| ImageInputFieldTemplate | ImageInputFieldTemplate
| InpaintMaskInputFieldTemplate | DenoiseMaskInputFieldTemplate
| LatentsInputFieldTemplate | LatentsInputFieldTemplate
| ConditioningInputFieldTemplate | ConditioningInputFieldTemplate
| UNetInputFieldTemplate | UNetInputFieldTemplate
@ -207,11 +207,11 @@ export const zConditioningField = z.object({
}); });
export type ConditioningField = z.infer<typeof zConditioningField>; export type ConditioningField = z.infer<typeof zConditioningField>;
export const zInpaintMaskField = z.object({ export const zDenoiseMaskField = z.object({
mask_name: z.string().trim().min(1), mask_name: z.string().trim().min(1),
masked_latents_name: z.string().trim().min(1).optional(), masked_latents_name: z.string().trim().min(1).optional(),
}); });
export type InpaintMaskFieldValue = z.infer<typeof zInpaintMaskField>; export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('integer'), type: z.literal('integer'),
@ -249,12 +249,12 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
}); });
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>; export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
export const zInpaintMaskInputFieldValue = zInputFieldValueBase.extend({ export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('InpaintMaskField'), type: z.literal('DenoiseMaskField'),
value: zInpaintMaskField.optional(), value: zDenoiseMaskField.optional(),
}); });
export type InpaintMaskInputFieldValue = z.infer< export type DenoiseMaskInputFieldValue = z.infer<
typeof zInpaintMaskInputFieldValue typeof zDenoiseMaskInputFieldValue
>; >;
export const zConditioningInputFieldValue = zInputFieldValueBase.extend({ export const zConditioningInputFieldValue = zInputFieldValueBase.extend({
@ -475,7 +475,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zBooleanInputFieldValue, zBooleanInputFieldValue,
zImageInputFieldValue, zImageInputFieldValue,
zLatentsInputFieldValue, zLatentsInputFieldValue,
zInpaintMaskInputFieldValue, zDenoiseMaskInputFieldValue,
zConditioningInputFieldValue, zConditioningInputFieldValue,
zUNetInputFieldValue, zUNetInputFieldValue,
zClipInputFieldValue, zClipInputFieldValue,
@ -549,9 +549,9 @@ export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'ImageCollection'; type: 'ImageCollection';
}; };
export type InpaintMaskInputFieldTemplate = InputFieldTemplateBase & { export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
default: undefined; default: undefined;
type: 'InpaintMaskField'; type: 'DenoiseMaskField';
}; };
export type LatentsInputFieldTemplate = InputFieldTemplateBase & { export type LatentsInputFieldTemplate = InputFieldTemplateBase & {

View File

@ -8,12 +8,12 @@ import {
ConditioningInputFieldTemplate, ConditioningInputFieldTemplate,
ControlInputFieldTemplate, ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate, ControlNetModelInputFieldTemplate,
DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate, EnumInputFieldTemplate,
FieldType, FieldType,
FloatInputFieldTemplate, FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate, ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate, ImageInputFieldTemplate,
InpaintMaskInputFieldTemplate,
InputFieldTemplateBase, InputFieldTemplateBase,
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
InvocationFieldSchema, InvocationFieldSchema,
@ -264,13 +264,13 @@ const buildImageCollectionInputFieldTemplate = ({
return template; return template;
}; };
const buildInpaintMaskInputFieldTemplate = ({ const buildDenoiseMaskInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
}: BuildInputFieldArg): InpaintMaskInputFieldTemplate => { }: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => {
const template: InpaintMaskInputFieldTemplate = { const template: DenoiseMaskInputFieldTemplate = {
...baseField, ...baseField,
type: 'InpaintMaskField', type: 'DenoiseMaskField',
default: schemaObject.default ?? undefined, default: schemaObject.default ?? undefined,
}; };
@ -512,8 +512,8 @@ export const buildInputFieldTemplate = (
baseField, baseField,
}); });
} }
if (fieldType === 'InpaintMaskField') { if (fieldType === 'DenoiseMaskField') {
return buildInpaintMaskInputFieldTemplate({ return buildDenoiseMaskInputFieldTemplate({
schemaObject: fieldSchema, schemaObject: fieldSchema,
baseField, baseField,
}); });

View File

@ -49,7 +49,7 @@ export const buildInputFieldValue = (
fieldValue.value = []; fieldValue.value = [];
} }
if (template.type === 'InpaintMaskField') { if (template.type === 'DenoiseMaskField') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
CreateInpaintMaskInvocation, CreateDenoiseMaskInvocation,
ImageBlurInvocation, ImageBlurInvocation,
ImageDTO, ImageDTO,
ImageToLatentsInvocation, ImageToLatentsInvocation,
@ -130,7 +130,7 @@ export const buildCanvasInpaintGraph = (
fp32: vaePrecision === 'fp32' ? true : false, fp32: vaePrecision === 'fp32' ? true : false,
}, },
[INPAINT_CREATE_MASK]: { [INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask', type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK, id: INPAINT_CREATE_MASK,
is_intermediate: true, is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false, fp32: vaePrecision === 'fp32' ? true : false,
@ -298,11 +298,11 @@ export const buildCanvasInpaintGraph = (
{ {
source: { source: {
node_id: INPAINT_CREATE_MASK, node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask', field: 'denoise_mask',
}, },
destination: { destination: {
node_id: DENOISE_LATENTS, node_id: DENOISE_LATENTS,
field: 'mask', field: 'denoise_mask',
}, },
}, },
// Iterate // Iterate
@ -546,7 +546,7 @@ export const buildCanvasInpaintGraph = (
image: canvasMaskImage, image: canvasMaskImage,
}; };
graph.nodes[INPAINT_CREATE_MASK] = { graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation), ...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage, image: canvasInitImage,
}; };

View File

@ -155,7 +155,7 @@ export const buildCanvasOutpaintGraph = (
is_intermediate: true, is_intermediate: true,
}, },
[INPAINT_CREATE_MASK]: { [INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask', type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK, id: INPAINT_CREATE_MASK,
is_intermediate: true, is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false, fp32: vaePrecision === 'fp32' ? true : false,
@ -338,11 +338,11 @@ export const buildCanvasOutpaintGraph = (
{ {
source: { source: {
node_id: INPAINT_CREATE_MASK, node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask', field: 'denoise_mask',
}, },
destination: { destination: {
node_id: DENOISE_LATENTS, node_id: DENOISE_LATENTS,
field: 'mask', field: 'denoise_mask',
}, },
}, },
// Iterate // Iterate

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
CreateInpaintMaskInvocation, CreateDenoiseMaskInvocation,
ImageBlurInvocation, ImageBlurInvocation,
ImageDTO, ImageDTO,
ImageToLatentsInvocation, ImageToLatentsInvocation,
@ -139,7 +139,7 @@ export const buildCanvasSDXLInpaintGraph = (
is_intermediate: true, is_intermediate: true,
}, },
[INPAINT_CREATE_MASK]: { [INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask', type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK, id: INPAINT_CREATE_MASK,
is_intermediate: true, is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false, fp32: vaePrecision === 'fp32' ? true : false,
@ -312,11 +312,11 @@ export const buildCanvasSDXLInpaintGraph = (
{ {
source: { source: {
node_id: INPAINT_CREATE_MASK, node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask', field: 'denoise_mask',
}, },
destination: { destination: {
node_id: SDXL_DENOISE_LATENTS, node_id: SDXL_DENOISE_LATENTS,
field: 'mask', field: 'denoise_mask',
}, },
}, },
// Iterate // Iterate
@ -560,7 +560,7 @@ export const buildCanvasSDXLInpaintGraph = (
image: canvasMaskImage, image: canvasMaskImage,
}; };
graph.nodes[INPAINT_CREATE_MASK] = { graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation), ...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage, image: canvasInitImage,
}; };

View File

@ -158,7 +158,7 @@ export const buildCanvasSDXLOutpaintGraph = (
is_intermediate: true, is_intermediate: true,
}, },
[INPAINT_CREATE_MASK]: { [INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask', type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK, id: INPAINT_CREATE_MASK,
is_intermediate: true, is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false, fp32: vaePrecision === 'fp32' ? true : false,
@ -352,11 +352,11 @@ export const buildCanvasSDXLOutpaintGraph = (
{ {
source: { source: {
node_id: INPAINT_CREATE_MASK, node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask', field: 'denoise_mask',
}, },
destination: { destination: {
node_id: SDXL_DENOISE_LATENTS, node_id: SDXL_DENOISE_LATENTS,
field: 'mask', field: 'denoise_mask',
}, },
}, },
// Iterate // Iterate

File diff suppressed because one or more lines are too long

View File

@ -111,7 +111,7 @@ export type ImageBlurInvocation = s['ImageBlurInvocation'];
export type ImageScaleInvocation = s['ImageScaleInvocation']; export type ImageScaleInvocation = s['ImageScaleInvocation'];
export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation']; export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
export type InfillTileInvocation = s['InfillTileInvocation']; export type InfillTileInvocation = s['InfillTileInvocation'];
export type CreateInpaintMaskInvocation = s['CreateInpaintMaskInvocation']; export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation'];
export type RandomIntInvocation = s['RandomIntInvocation']; export type RandomIntInvocation = s['RandomIntInvocation'];
export type CompelInvocation = s['CompelInvocation']; export type CompelInvocation = s['CompelInvocation'];
export type DynamicPromptInvocation = s['DynamicPromptInvocation']; export type DynamicPromptInvocation = s['DynamicPromptInvocation'];