inpainting for the normal model. I think it works this time.

This commit is contained in:
Kevin Turner 2022-12-05 12:36:50 -08:00
parent 69d42762de
commit 5c7e6751e0

View File

@ -22,8 +22,9 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@ -53,39 +54,76 @@ _default_personalization_config_params = dict(
@dataclass
class AddsMaskLatents:
"""Add the channels required for inpainting model input.
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.
This class assumes the same mask and base image should apply to all items in the batch.
"""
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
initial_image_latents: torch.FloatTensor
def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
return model_input
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return (
isinstance(b, torch.Tensor)
and (a.size() == b.size())
)
@dataclass
class AddsMaskGuidance:
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
_scheduler: SchedulerMixin
_noise_func: Callable
_debug: Optional[Callable] = None
def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
# It's reasonable to assume the first thing is prev_sample, but then does it have other things
# like pred_original_sample? Should we apply the mask to them too?
# But what if there's just some other random field?
prev_sample = step_output[0]
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{k: (self.apply_mask(v, self._t_for_field(k, t))
if are_like_tensors(prev_sample, v) else v)
for k, v in step_output.items()}
)
def _t_for_field(self, field_name:str, t):
if field_name == "pred_original_sample":
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
return t
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
noise = self._noise_func(self.mask_latents)
mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0]) # .to(dtype=mask_latents.dtype)
mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
# if self._debug:
# self._debug(latents, f"t={t[0]} latents")
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self._debug:
self._debug(masked_input, f"t={t[0]} lerped")
return self.forward(masked_input, t, text_embeddings)
self._debug(masked_input, f"t={t} lerped")
return masked_input
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
@ -263,10 +301,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None,
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
timesteps = None,
additional_guidance: List[Callable] = None,
**extra_step_kwargs):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None:
additional_guidance = []
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
step_count=len(self.scheduler.timesteps))
@ -289,7 +331,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, guidance_scale,
text_embeddings, unconditioned_embeddings,
i, **extra_step_kwargs)
i, additional_guidance=additional_guidance,
**extra_step_kwargs)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
@ -306,11 +349,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
step_index:int | None = None,
step_index:int | None = None, additional_guidance: List[Callable] = None,
**extra_step_kwargs):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
@ -323,7 +369,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index)
# compute the previous noisy sample x_t -> x_t-1
return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
step_output = self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
# But the way things are now, scheduler runs _after_ that, so there was
# no way to use it to apply an operation that happens after the last scheduler.step.
for guidance in additional_guidance:
step_output = guidance(step_output, timestep, (unconditioned_embeddings, text_embeddings))
return step_output
def _unet_forward(self, latents, t, text_embeddings):
# predict the noise residual
@ -401,6 +455,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
assert img2img_pipeline.scheduler is self.scheduler
# 6. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
@ -410,13 +466,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
.to(device=device, dtype=latents_dtype)
guidance: List[Callable] = []
if is_inpainting_model(self.unet):
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = \
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
else:
self.invokeai_diffuser.model_forward_callback = \
AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
self.scheduler, noise_func) # self.debug_latents)
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
result = None
@ -425,7 +482,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents, text_embeddings, unconditioned_embeddings, guidance_scale,
extra_conditioning_info=extra_conditioning_info,
timesteps=timesteps,
run_id=run_id, **extra_step_kwargs):
run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
if callback is not None and isinstance(result, PipelineIntermediateState):
callback(result)
if result is None: