Mirror of https://github.com/invoke-ai/InvokeAI
Add mask to l2l
parent 5f29526a8e
commit 96b7248051
invokeai
@@ -41,6 +41,9 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 
+import torchvision.transforms as T
+from torchvision.transforms.functional import resize as tv_resize
+
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
@@ -397,6 +400,7 @@ class TextToLatentsInvocation(BaseInvocation):
         result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
             latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
             noise=noise,
+            seed=seed,
             timesteps=timesteps,
             num_inference_steps=num_inference_steps,
             conditioning_data=conditioning_data,
@@ -424,7 +428,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
     #denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
 
     latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
 
+    mask: Optional[ImageField] = Field(
+        None, description="Mask",
+    )
+
     # Schema customisation
     class Config(InvocationConfig):
@@ -440,6 +448,22 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             },
         }
 
+    def prep_mask_tensor(self, mask, context, latents):
+        if mask is None:
+            return None
+
+        mask_image = context.services.images.get_pil_image(mask.image_name)
+        if mask_image.mode != "L":
+            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+            mask_image = mask_image.convert("L")
+        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        if mask_tensor.dim() == 3:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        mask_tensor = tv_resize(
+            mask_tensor, latents.shape[-2:], T.InterpolationMode.BILINEAR
+        )
+        return mask_tensor
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
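The new prep_mask_tensor helper boils down to three steps: force the mask to single-channel, convert it to a float tensor in [0, 1], and resize it to the latent grid. A self-contained sketch of the same steps in plain torchvision (mask_to_latent_tensor is an illustrative name, and T.ToTensor() stands in for image_resized_to_grid_as_tensor, which this diff does not show):

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize

def mask_to_latent_tensor(mask_image: Image.Image, latents: torch.Tensor) -> torch.Tensor:
    # The mask is a per-pixel blend weight, not a color image: collapse to one channel.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")
    # PIL -> float tensor in [0, 1], shape (1, H, W); add a batch dim -> (1, 1, H, W).
    mask_tensor = T.ToTensor()(mask_image).unsqueeze(0)
    # Match the latent grid (for Stable Diffusion, 8x smaller than pixel space).
    return tv_resize(mask_tensor, list(latents.shape[-2:]), T.InterpolationMode.BILINEAR)

# A 512x512 mask lands on a 64x64 latent grid:
latents = torch.zeros(1, 4, 64, 64)
mask = Image.new("L", (512, 512), 255)
print(mask_to_latent_tensor(mask, latents).shape)  # torch.Size([1, 1, 64, 64])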
@@ -452,6 +476,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         if self.noise.seed is not None:
             seed = self.noise.seed
 
+        mask = self.prep_mask_tensor(self.mask, context, latent)
+
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
@@ -479,6 +505,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         if noise is not None:
             noise = noise.to(device=unet.device, dtype=unet.dtype)
         latent = latent.to(device=unet.device, dtype=unet.dtype)
+        mask = mask.to(device=unet.device, dtype=unet.dtype) if mask is not None else None
 
         scheduler = get_scheduler(
             context=context,
@@ -516,6 +543,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             latents=initial_latents,
             timesteps=timesteps,
             noise=noise,
+            seed=seed,
+            mask=mask,
             num_inference_steps=num_inference_steps,
             conditioning_data=conditioning_data,
             control_data=control_data,  # list[ControlNetData]
 
@@ -100,7 +100,7 @@ class AddsMaskGuidance:
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
     scheduler: SchedulerMixin
-    noise: torch.Tensor
+    noise: Optional[torch.Tensor] = None
     _debug: Optional[Callable] = None
 
     def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
@@ -131,11 +131,10 @@ class AddsMaskGuidance:
             # some schedulers expect t to be one-dimensional.
             # TODO: file diffusers bug about inconsistency?
             t = einops.repeat(t, "-> batch", batch=batch_size)
-            # Noise shouldn't be re-randomized between steps here. The multistep schedulers
-            # get very confused about what is happening from step to step when we do that.
-            mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
-            # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
-            # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
+
+            if self.noise is not None:
+                mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
+
             mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
             masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
             if self._debug:
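The guidance step's core operation: bring the reference-image latents to the current step's noise level (now skipped when no noise was stored, matching the new Optional field), then blend them with the working latents using the mask as the interpolation weight. A minimal free-standing sketch (the function form is illustrative; in the diff this logic lives inside AddsMaskGuidance.__call__):

import torch

def blend_with_mask(latents, mask_latents, mask, scheduler, noise, t):
    if noise is not None:
        # Re-noise the reference latents to timestep t so both sides of the
        # blend sit at the same noise level.
        mask_latents = scheduler.add_noise(mask_latents, noise, t)
    # torch.lerp(a, b, w) = a + w * (b - a): mask == 0 keeps the reference,
    # mask == 1 keeps the freshly denoised content.
    return torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))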
@@ -408,7 +407,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
+        mask: Optional[torch.Tensor] = None,
+        seed: Optional[int] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+        # TODO:
         if self.scheduler.config.get("cpu_only", False):
             scheduler_device = torch.device("cpu")
         else:
@@ -417,19 +419,74 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps is None:
             self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
             timesteps = self.scheduler.timesteps
 
         infer_latents_from_embeddings = GeneratorToCallbackinator(
             self.generate_latents_from_embeddings, PipelineIntermediateState
         )
-        result: PipelineIntermediateState = infer_latents_from_embeddings(
-            latents,
-            timesteps,
-            conditioning_data,
-            noise=noise,
-            additional_guidance=additional_guidance,
-            control_data=control_data,
-            callback=callback,
+
+        if additional_guidance is None:
+            additional_guidance = []
+
+        orig_latents = latents.clone()
+
+        batch_size = latents.shape[0]
+        batched_t = torch.full(
+            (batch_size,),
+            timesteps[0],
+            dtype=timesteps.dtype,
+            device=self.unet.device,
         )
-        return result.latents, result.attention_map_saver
+
+        if noise is not None:
+            # latents = noise * self.scheduler.init_noise_sigma  # it's like in t2l according to diffusers
+            latents = self.scheduler.add_noise(latents, noise, batched_t)
+
+        else:
+            # if no noise provided, noisify unmasked area based on seed (or 0 as fallback)
+            if mask is not None:
+                noise = torch.randn(
+                    orig_latents.shape,
+                    dtype=torch.float32,
+                    device="cpu",
+                    generator=torch.Generator(device="cpu").manual_seed(seed or 0),
+                ).to(device=orig_latents.device, dtype=orig_latents.dtype)
+
+                latents = self.scheduler.add_noise(latents, noise, batched_t)
+                latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
+
+
+        if mask is not None:
+            if is_inpainting_model(self.unet):
+                # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
+                # (that's why there's a mask!) but it seems to really want that blanked out.
+                # masked_latents = latents * torch.where(mask < 0.5, 1, 0)  TODO: inpaint/outpaint/infill
+
+                # TODO: we should probably pass this in so we don't have to try/finally around setting it.
+                self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
+                    self._unet_forward, mask, orig_latents
+                )
+            else:
+                additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
+
+        try:
+            result: PipelineIntermediateState = infer_latents_from_embeddings(
+                latents,
+                timesteps,
+                conditioning_data,
+                additional_guidance=additional_guidance,
+                control_data=control_data,
+                callback=callback,
+            )
+        finally:
+            self.invokeai_diffuser.model_forward_callback = self._unet_forward
+
+        latents = result.latents
+
+        # restore unmasked part
+        if mask is not None:
+            latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
+
+        return latents, result.attention_map_saver
+
     def generate_latents_from_embeddings(
         self,
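When the caller supplies a mask but no starting noise, latents_from_embeddings now synthesizes noise deterministically from the seed, so re-noising the unmasked region is reproducible; the same torch.lerp used during denoising then restores the unmasked part at the end. The deterministic-noise idiom on its own (seeded_noise_like is a hypothetical helper name):

import torch

def seeded_noise_like(latents: torch.Tensor, seed: int) -> torch.Tensor:
    # Sample on the CPU with an explicitly seeded generator, then move to the
    # latents' device/dtype; CPU sampling keeps results identical across GPUs.
    generator = torch.Generator(device="cpu").manual_seed(seed or 0)
    noise = torch.randn(latents.shape, dtype=torch.float32, device="cpu", generator=generator)
    return noise.to(device=latents.device, dtype=latents.dtype)

orig = torch.zeros(1, 4, 64, 64)
assert torch.equal(seeded_noise_like(orig, 123), seeded_noise_like(orig, 123))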
@@ -437,7 +494,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         timesteps,
         conditioning_data: ConditioningData,
         *,
-        noise: Optional[torch.Tensor],
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
     ):
@@ -457,9 +513,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             dtype=timesteps.dtype,
             device=self._model_group.device_for(self.unet),
         )
-        if noise is not None:
-            # latents = noise * self.scheduler.init_noise_sigma  # it's like in t2l according to diffusers
-            latents = self.scheduler.add_noise(latents, noise, batched_t)
 
         yield PipelineIntermediateState(
             step=-1,
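The initial-noising block removed here moves up into latents_from_embeddings (the large hunk above). The underlying idiom, as a runnable illustration with a stock diffusers DDPMScheduler rather than the pipeline's actual scheduler setup:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=20)

latents = torch.zeros(1, 4, 64, 64)  # clean starting latents
noise = torch.randn_like(latents)

# One timestep per batch item; timesteps[0] is the noisiest step in the
# schedule, so this noises the latents to the very start of denoising.
batched_t = torch.full((latents.shape[0],), scheduler.timesteps[0], dtype=scheduler.timesteps.dtype)
noisy = scheduler.add_noise(latents, noise, batched_t)
print(noisy.shape)  # torch.Size([1, 4, 64, 64])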