mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add Initial (non-working) Seamless Implementation
This commit is contained in:
parent
b5a83bbc8a
commit
95883c2efd
@ -21,6 +21,8 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
@ -31,8 +33,8 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
@ -44,16 +46,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
@ -64,6 +57,72 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@title("Create Denoise Mask")
|
||||
@tags("mask", "denoise")
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
# Metadata
|
||||
type: Literal["create_denoise_mask"] = "create_denoise_mask"
|
||||
|
||||
# Inputs
|
||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
|
||||
|
||||
def prep_mask_tensor(self, mask_image):
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
# if shape is not None:
|
||||
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
|
||||
return mask_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
else:
|
||||
image = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.services.images.get_pil_image(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
|
||||
context.services.latents.save(masked_latents_name, masked_latents)
|
||||
else:
|
||||
masked_latents_name = None
|
||||
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
return DenoiseMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
@ -126,13 +185,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
|
||||
)
|
||||
mask: Optional[ImageField] = InputField(
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
)
|
||||
seamless: bool = InputField(default=False, description="Enable or disable seamless padding")
|
||||
seamless_axes: str = InputField(default="xy", description="Specify which axes are seamless: 'x', 'y', or 'xy'")
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@ -199,17 +258,46 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
def configure_model_padding(self, model):
|
||||
"""
|
||||
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
||||
"""
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = torch.nn.functional.pad(input, self.asymmetric_padding['x'], mode=self.asymmetric_padding_mode['x'])
|
||||
working = torch.nn.functional.pad(working, self.asymmetric_padding['y'], mode=self.asymmetric_padding_mode['y'])
|
||||
return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
if self.seamless:
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode['x'] = 'circular' if ('x' in self.seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['x'] = (m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0)
|
||||
m.asymmetric_padding_mode['y'] = 'circular' if ('y' in self.seamless_axes) else 'constant'
|
||||
m.asymmetric_padding['y'] = (0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3])
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, torch.nn.Conv2d)
|
||||
else:
|
||||
m._conv_forward = torch.nn.Conv2d._conv_forward.__get__(m, torch.nn.Conv2d)
|
||||
if hasattr(m, 'asymmetric_padding_mode'):
|
||||
del m.asymmetric_padding_mode
|
||||
if hasattr(m, 'asymmetric_padding'):
|
||||
del m.asymmetric_padding
|
||||
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
# )
|
||||
|
||||
self.configure_model_padding(
|
||||
unet
|
||||
)
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
@ -342,19 +430,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_mask_tensor(self, mask, context, lantents):
|
||||
if mask is None:
|
||||
return None
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask_image = context.services.images.get_pil_image(mask.image_name)
|
||||
if mask_image.mode != "L":
|
||||
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
return 1 - mask_tensor
|
||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
return 1 - mask, masked_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@ -375,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
mask = self.prep_mask_tensor(self.mask, context, latents)
|
||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
@ -406,6 +493,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
if mask is not None:
|
||||
mask = mask.to(device=unet.device, dtype=unet.dtype)
|
||||
if masked_latents is not None:
|
||||
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
@ -442,6 +531,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
@ -663,26 +753,11 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
# image = context.services.images.get(
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info, upcast, tiled, image_tensor):
|
||||
with vae_info as vae:
|
||||
orig_dtype = vae.dtype
|
||||
if self.fp32:
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
@ -707,7 +782,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
# latents = latents.half()
|
||||
|
||||
if self.tiled:
|
||||
if tiled:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@ -721,6 +796,23 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
|
Loading…
x
Reference in New Issue
Block a user