rename: Inpaint Mask to Denoise Mask

This commit is contained in:
blessedcoolant 2023-08-27 05:50:13 +12:00
parent 226721ce51
commit c923d094c6
15 changed files with 137 additions and 137 deletions

View File

@ -21,10 +21,10 @@ from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
DenoiseMaskOutput,
ImageField,
ImageOutput,
InpaintMaskField,
InpaintMaskOutput,
LatentsField,
LatentsOutput,
build_latents_output,
@ -57,16 +57,16 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@title("Create Inpaint Mask")
@tags("mask", "inpaint")
class CreateInpaintMaskInvocation(BaseInvocation):
"""Creates mask for inpaint model run."""
@title("Create Denoise Mask")
@tags("mask", "denoise")
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
# Metadata
type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
type: Literal["create_denoise_mask"] = "create_denoise_mask"
# Inputs
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked")
mask: ImageField = InputField(description="The mask to use when pasting")
vae: VaeField = InputField(
description=FieldDescriptions.vae,
@ -86,7 +86,7 @@ class CreateInpaintMaskInvocation(BaseInvocation):
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -118,8 +118,8 @@ class CreateInpaintMaskInvocation(BaseInvocation):
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)
return InpaintMaskOutput(
inpaint_mask=InpaintMaskField(
return DenoiseMaskOutput(
denoise_mask=DenoiseMaskField(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
),
@ -189,7 +189,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[InpaintMaskField] = InputField(
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
)
@ -403,13 +403,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context, latents):
if self.mask is None:
if self.denoise_mask is None:
return None, None
mask = context.services.latents.get(self.mask.mask_name)
mask = context.services.latents.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
if self.mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.mask.masked_latents_name)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None

View File

@ -296,21 +296,21 @@ class ImageCollectionInvocation(BaseInvocation):
# endregion
# region InpaintMask
# region DenoiseMask
class InpaintMaskField(BaseModel):
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
class InpaintMaskOutput(BaseInvocationOutput):
class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
type: Literal["denoise_mask_output"] = "denoise_mask_output"
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
# endregion

View File

@ -10,10 +10,10 @@ import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField';
import InpaintMaskInputField from './inputs/InpaintMaskInputField';
import LatentsInputField from './inputs/LatentsInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField';
@ -107,11 +107,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
field?.type === 'InpaintMaskField' &&
fieldTemplate?.type === 'InpaintMaskField'
field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'DenoiseMaskField'
) {
return (
<InpaintMaskInputField
<DenoiseMaskInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}

View File

@ -0,0 +1,17 @@
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(DenoiseMaskInputFieldComponent);

View File

@ -1,17 +0,0 @@
import {
FieldComponentProps,
InpaintMaskInputFieldTemplate,
InpaintMaskInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
const InpaintMaskInputFieldComponent = (
_props: FieldComponentProps<
InpaintMaskInputFieldValue,
InpaintMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(InpaintMaskInputFieldComponent);

View File

@ -59,9 +59,9 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'Images may be passed between nodes.',
color: 'purple.500',
},
InpaintMaskField: {
title: 'Inpaint Mask',
description: 'Inpaint Mask may be passed between nodes',
DenoiseMaskField: {
title: 'Denoise Mask',
description: 'Denoise Mask may be passed between nodes',
color: 'purple.500',
},
LatentsField: {

View File

@ -64,7 +64,7 @@ export const zFieldType = z.enum([
'string',
'array',
'ImageField',
'InpaintMaskField',
'DenoiseMaskField',
'LatentsField',
'ConditioningField',
'ControlField',
@ -121,7 +121,7 @@ export type InputFieldTemplate =
| StringInputFieldTemplate
| BooleanInputFieldTemplate
| ImageInputFieldTemplate
| InpaintMaskInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| UNetInputFieldTemplate
@ -207,11 +207,11 @@ export const zConditioningField = z.object({
});
export type ConditioningField = z.infer<typeof zConditioningField>;
export const zInpaintMaskField = z.object({
export const zDenoiseMaskField = z.object({
mask_name: z.string().trim().min(1),
masked_latents_name: z.string().trim().min(1).optional(),
});
export type InpaintMaskFieldValue = z.infer<typeof zInpaintMaskField>;
export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('integer'),
@ -249,12 +249,12 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
});
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
export const zInpaintMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('InpaintMaskField'),
value: zInpaintMaskField.optional(),
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('DenoiseMaskField'),
value: zDenoiseMaskField.optional(),
});
export type InpaintMaskInputFieldValue = z.infer<
typeof zInpaintMaskInputFieldValue
export type DenoiseMaskInputFieldValue = z.infer<
typeof zDenoiseMaskInputFieldValue
>;
export const zConditioningInputFieldValue = zInputFieldValueBase.extend({
@ -475,7 +475,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zBooleanInputFieldValue,
zImageInputFieldValue,
zLatentsInputFieldValue,
zInpaintMaskInputFieldValue,
zDenoiseMaskInputFieldValue,
zConditioningInputFieldValue,
zUNetInputFieldValue,
zClipInputFieldValue,
@ -549,9 +549,9 @@ export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'ImageCollection';
};
export type InpaintMaskInputFieldTemplate = InputFieldTemplateBase & {
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'InpaintMaskField';
type: 'DenoiseMaskField';
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {

View File

@ -8,12 +8,12 @@ import {
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate,
InpaintMaskInputFieldTemplate,
InputFieldTemplateBase,
IntegerInputFieldTemplate,
InvocationFieldSchema,
@ -264,13 +264,13 @@ const buildImageCollectionInputFieldTemplate = ({
return template;
};
const buildInpaintMaskInputFieldTemplate = ({
const buildDenoiseMaskInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): InpaintMaskInputFieldTemplate => {
const template: InpaintMaskInputFieldTemplate = {
}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => {
const template: DenoiseMaskInputFieldTemplate = {
...baseField,
type: 'InpaintMaskField',
type: 'DenoiseMaskField',
default: schemaObject.default ?? undefined,
};
@ -512,8 +512,8 @@ export const buildInputFieldTemplate = (
baseField,
});
}
if (fieldType === 'InpaintMaskField') {
return buildInpaintMaskInputFieldTemplate({
if (fieldType === 'DenoiseMaskField') {
return buildDenoiseMaskInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});

View File

@ -49,7 +49,7 @@ export const buildInputFieldValue = (
fieldValue.value = [];
}
if (template.type === 'InpaintMaskField') {
if (template.type === 'DenoiseMaskField') {
fieldValue.value = undefined;
}

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CreateInpaintMaskInvocation,
CreateDenoiseMaskInvocation,
ImageBlurInvocation,
ImageDTO,
ImageToLatentsInvocation,
@ -130,7 +130,7 @@ export const buildCanvasInpaintGraph = (
fp32: vaePrecision === 'fp32' ? true : false,
},
[INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask',
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
@ -298,11 +298,11 @@ export const buildCanvasInpaintGraph = (
{
source: {
node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask',
field: 'denoise_mask',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'mask',
field: 'denoise_mask',
},
},
// Iterate
@ -546,7 +546,7 @@ export const buildCanvasInpaintGraph = (
image: canvasMaskImage,
};
graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation),
...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};

View File

@ -155,7 +155,7 @@ export const buildCanvasOutpaintGraph = (
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask',
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
@ -338,11 +338,11 @@ export const buildCanvasOutpaintGraph = (
{
source: {
node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask',
field: 'denoise_mask',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'mask',
field: 'denoise_mask',
},
},
// Iterate

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CreateInpaintMaskInvocation,
CreateDenoiseMaskInvocation,
ImageBlurInvocation,
ImageDTO,
ImageToLatentsInvocation,
@ -139,7 +139,7 @@ export const buildCanvasSDXLInpaintGraph = (
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask',
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
@ -312,11 +312,11 @@ export const buildCanvasSDXLInpaintGraph = (
{
source: {
node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask',
field: 'denoise_mask',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'mask',
field: 'denoise_mask',
},
},
// Iterate
@ -560,7 +560,7 @@ export const buildCanvasSDXLInpaintGraph = (
image: canvasMaskImage,
};
graph.nodes[INPAINT_CREATE_MASK] = {
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation),
...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};

View File

@ -158,7 +158,7 @@ export const buildCanvasSDXLOutpaintGraph = (
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_inpaint_mask',
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
@ -352,11 +352,11 @@ export const buildCanvasSDXLOutpaintGraph = (
{
source: {
node_id: INPAINT_CREATE_MASK,
field: 'inpaint_mask',
field: 'denoise_mask',
},
destination: {
node_id: SDXL_DENOISE_LATENTS,
field: 'mask',
field: 'denoise_mask',
},
},
// Iterate

File diff suppressed because one or more lines are too long

View File

@ -111,7 +111,7 @@ export type ImageBlurInvocation = s['ImageBlurInvocation'];
export type ImageScaleInvocation = s['ImageScaleInvocation'];
export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
export type InfillTileInvocation = s['InfillTileInvocation'];
export type CreateInpaintMaskInvocation = s['CreateInpaintMaskInvocation'];
export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation'];
export type RandomIntInvocation = s['RandomIntInvocation'];
export type CompelInvocation = s['CompelInvocation'];
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];