mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
rename: Inpaint Mask to Denoise Mask
This commit is contained in:
parent
226721ce51
commit
c923d094c6
@ -21,10 +21,10 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
InpaintMaskField,
|
||||
InpaintMaskOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
@ -57,16 +57,16 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@title("Create Inpaint Mask")
|
||||
@tags("mask", "inpaint")
|
||||
class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for inpaint model run."""
|
||||
@title("Create Denoise Mask")
|
||||
@tags("mask", "denoise")
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
# Metadata
|
||||
type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
|
||||
type: Literal["create_denoise_mask"] = "create_denoise_mask"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked")
|
||||
mask: ImageField = InputField(description="The mask to use when pasting")
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
@ -86,7 +86,7 @@ class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
return mask_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
@ -118,8 +118,8 @@ class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
return InpaintMaskOutput(
|
||||
inpaint_mask=InpaintMaskField(
|
||||
return DenoiseMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
),
|
||||
@ -189,7 +189,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[InpaintMaskField] = InputField(
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
)
|
||||
@ -403,13 +403,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
if self.mask is None:
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.services.latents.get(self.mask.mask_name)
|
||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
if self.mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.mask.masked_latents_name)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
|
@ -296,21 +296,21 @@ class ImageCollectionInvocation(BaseInvocation):
|
||||
|
||||
# endregion
|
||||
|
||||
# region InpaintMask
|
||||
# region DenoiseMask
|
||||
|
||||
|
||||
class InpaintMaskField(BaseModel):
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
|
||||
|
||||
|
||||
class InpaintMaskOutput(BaseInvocationOutput):
|
||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
|
||||
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
|
||||
type: Literal["denoise_mask_output"] = "denoise_mask_output"
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
|
||||
|
||||
# endregion
|
||||
|
@ -10,10 +10,10 @@ import ColorInputField from './inputs/ColorInputField';
|
||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
||||
import ControlInputField from './inputs/ControlInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import InpaintMaskInputField from './inputs/InpaintMaskInputField';
|
||||
import LatentsInputField from './inputs/LatentsInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
@ -107,11 +107,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'InpaintMaskField' &&
|
||||
fieldTemplate?.type === 'InpaintMaskField'
|
||||
field?.type === 'DenoiseMaskField' &&
|
||||
fieldTemplate?.type === 'DenoiseMaskField'
|
||||
) {
|
||||
return (
|
||||
<InpaintMaskInputField
|
||||
<DenoiseMaskInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
|
@ -0,0 +1,17 @@
|
||||
import {
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
DenoiseMaskInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const DenoiseMaskInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
DenoiseMaskInputFieldValue,
|
||||
DenoiseMaskInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(DenoiseMaskInputFieldComponent);
|
@ -1,17 +0,0 @@
|
||||
import {
|
||||
FieldComponentProps,
|
||||
InpaintMaskInputFieldTemplate,
|
||||
InpaintMaskInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const InpaintMaskInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
InpaintMaskInputFieldValue,
|
||||
InpaintMaskInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(InpaintMaskInputFieldComponent);
|
@ -59,9 +59,9 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
description: 'Images may be passed between nodes.',
|
||||
color: 'purple.500',
|
||||
},
|
||||
InpaintMaskField: {
|
||||
title: 'Inpaint Mask',
|
||||
description: 'Inpaint Mask may be passed between nodes',
|
||||
DenoiseMaskField: {
|
||||
title: 'Denoise Mask',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
color: 'purple.500',
|
||||
},
|
||||
LatentsField: {
|
||||
|
@ -64,7 +64,7 @@ export const zFieldType = z.enum([
|
||||
'string',
|
||||
'array',
|
||||
'ImageField',
|
||||
'InpaintMaskField',
|
||||
'DenoiseMaskField',
|
||||
'LatentsField',
|
||||
'ConditioningField',
|
||||
'ControlField',
|
||||
@ -121,7 +121,7 @@ export type InputFieldTemplate =
|
||||
| StringInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| InpaintMaskInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
@ -207,11 +207,11 @@ export const zConditioningField = z.object({
|
||||
});
|
||||
export type ConditioningField = z.infer<typeof zConditioningField>;
|
||||
|
||||
export const zInpaintMaskField = z.object({
|
||||
export const zDenoiseMaskField = z.object({
|
||||
mask_name: z.string().trim().min(1),
|
||||
masked_latents_name: z.string().trim().min(1).optional(),
|
||||
});
|
||||
export type InpaintMaskFieldValue = z.infer<typeof zInpaintMaskField>;
|
||||
export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
|
||||
|
||||
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('integer'),
|
||||
@ -249,12 +249,12 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
||||
|
||||
export const zInpaintMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('InpaintMaskField'),
|
||||
value: zInpaintMaskField.optional(),
|
||||
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('DenoiseMaskField'),
|
||||
value: zDenoiseMaskField.optional(),
|
||||
});
|
||||
export type InpaintMaskInputFieldValue = z.infer<
|
||||
typeof zInpaintMaskInputFieldValue
|
||||
export type DenoiseMaskInputFieldValue = z.infer<
|
||||
typeof zDenoiseMaskInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningInputFieldValue = zInputFieldValueBase.extend({
|
||||
@ -475,7 +475,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zBooleanInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zInpaintMaskInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zClipInputFieldValue,
|
||||
@ -549,9 +549,9 @@ export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ImageCollection';
|
||||
};
|
||||
|
||||
export type InpaintMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'InpaintMaskField';
|
||||
type: 'DenoiseMaskField';
|
||||
};
|
||||
|
||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
|
@ -8,12 +8,12 @@ import {
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
InpaintMaskInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerInputFieldTemplate,
|
||||
InvocationFieldSchema,
|
||||
@ -264,13 +264,13 @@ const buildImageCollectionInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildInpaintMaskInputFieldTemplate = ({
|
||||
const buildDenoiseMaskInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): InpaintMaskInputFieldTemplate => {
|
||||
const template: InpaintMaskInputFieldTemplate = {
|
||||
}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => {
|
||||
const template: DenoiseMaskInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'InpaintMaskField',
|
||||
type: 'DenoiseMaskField',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
@ -512,8 +512,8 @@ export const buildInputFieldTemplate = (
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'InpaintMaskField') {
|
||||
return buildInpaintMaskInputFieldTemplate({
|
||||
if (fieldType === 'DenoiseMaskField') {
|
||||
return buildDenoiseMaskInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
|
@ -49,7 +49,7 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = [];
|
||||
}
|
||||
|
||||
if (template.type === 'InpaintMaskField') {
|
||||
if (template.type === 'DenoiseMaskField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
CreateInpaintMaskInvocation,
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
@ -130,7 +130,7 @@ export const buildCanvasInpaintGraph = (
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_inpaint_mask',
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
@ -298,11 +298,11 @@ export const buildCanvasInpaintGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'inpaint_mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
},
|
||||
// Iterate
|
||||
@ -546,7 +546,7 @@ export const buildCanvasInpaintGraph = (
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
graph.nodes[INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation),
|
||||
...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
|
||||
|
@ -155,7 +155,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_inpaint_mask',
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
@ -338,11 +338,11 @@ export const buildCanvasOutpaintGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'inpaint_mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
},
|
||||
// Iterate
|
||||
|
@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
CreateInpaintMaskInvocation,
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
@ -139,7 +139,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
is_intermediate: true,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_inpaint_mask',
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
@ -312,11 +312,11 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'inpaint_mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
},
|
||||
// Iterate
|
||||
@ -560,7 +560,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
graph.nodes[INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[INPAINT_CREATE_MASK] as CreateInpaintMaskInvocation),
|
||||
...(graph.nodes[INPAINT_CREATE_MASK] as CreateDenoiseMaskInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
|
||||
|
@ -158,7 +158,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_inpaint_mask',
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
@ -352,11 +352,11 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'inpaint_mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_DENOISE_LATENTS,
|
||||
field: 'mask',
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
},
|
||||
// Iterate
|
||||
|
110
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
110
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -111,7 +111,7 @@ export type ImageBlurInvocation = s['ImageBlurInvocation'];
|
||||
export type ImageScaleInvocation = s['ImageScaleInvocation'];
|
||||
export type InfillPatchMatchInvocation = s['InfillPatchMatchInvocation'];
|
||||
export type InfillTileInvocation = s['InfillTileInvocation'];
|
||||
export type CreateInpaintMaskInvocation = s['CreateInpaintMaskInvocation'];
|
||||
export type CreateDenoiseMaskInvocation = s['CreateDenoiseMaskInvocation'];
|
||||
export type RandomIntInvocation = s['RandomIntInvocation'];
|
||||
export type CompelInvocation = s['CompelInvocation'];
|
||||
export type DynamicPromptInvocation = s['DynamicPromptInvocation'];
|
||||
|
Loading…
Reference in New Issue
Block a user