mirror of https://github.com/invoke-ai/InvokeAI
inpainting for the normal model. I think it works this time.
This commit is contained in:
parent 69d42762de
commit 5c7e6751e0
@@ -22,8 +22,9 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
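The two new imports matter further down: AddsMaskGuidance treats the scheduler's step result generically, indexing it like a tuple (step_output[0]), iterating it like a dict (step_output.items()), and rebuilding it from a plain dict of fields. A minimal sketch of the diffusers BaseOutput behavior this relies on, assuming a recent diffusers release (FakeStepOutput is illustrative, not part of the commit):

    # Illustrative only: FakeStepOutput is not part of the commit.
    from dataclasses import dataclass
    import torch
    from diffusers.utils.outputs import BaseOutput

    @dataclass
    class FakeStepOutput(BaseOutput):
        prev_sample: torch.FloatTensor = None
        pred_original_sample: torch.FloatTensor = None

    out = FakeStepOutput(prev_sample=torch.zeros(1, 4, 8, 8),
                         pred_original_sample=torch.ones(1, 4, 8, 8))
    assert out[0] is out.prev_sample                          # tuple-style indexing
    rebuilt = FakeStepOutput({k: v for k, v in out.items()})  # rebuild from a dict of fields
    assert torch.equal(rebuilt.pred_original_sample, out.pred_original_sample)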
@@ -53,39 +54,76 @@ _default_personalization_config_params = dict(
 
 @dataclass
 class AddsMaskLatents:
+    """Add the channels required for inpainting model input.
+
+    The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
+    and the latent encoding of the base image.
+
+    This class assumes the same mask and base image should apply to all items in the batch.
+    """
     forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
     mask: torch.FloatTensor
-    mask_latents: torch.FloatTensor
+    initial_image_latents: torch.FloatTensor
 
     def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
-        batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        model_input = self.add_mask_channels(latents)
         return self.forward(model_input, t, text_embeddings)
 
+    def add_mask_channels(self, latents):
+        batch_size = latents.size(0)
+        # duplicate mask and latents for each batch
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # add mask and image as additional channels
+        model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
+        return model_input
+
+
+def are_like_tensors(a: torch.Tensor, b: object) -> bool:
+    return (
+        isinstance(b, torch.Tensor)
+        and (a.size() == b.size())
+    )
+
 @dataclass
 class AddsMaskGuidance:
-    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
     _scheduler: SchedulerMixin
     _noise_func: Callable
     _debug: Optional[Callable] = None
 
-    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+    def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
+        output_class = step_output.__class__  # We'll create a new one with masked data.
+
+        # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
+        # It's reasonable to assume the first thing is prev_sample, but then does it have other things
+        # like pred_original_sample? Should we apply the mask to them too?
+        # But what if there's just some other random field?
+        prev_sample = step_output[0]
+        # Mask anything that has the same shape as prev_sample, return others as-is.
+        return output_class(
+            {k: (self.apply_mask(v, self._t_for_field(k, t))
+                 if are_like_tensors(prev_sample, v) else v)
+             for k, v in step_output.items()}
+        )
+
+    def _t_for_field(self, field_name:str, t):
+        if field_name == "pred_original_sample":
+            return torch.zeros_like(t, dtype=t.dtype)  # it represents t=0
+        return t
+
+    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         noise = self._noise_func(self.mask_latents)
-        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t)
         # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        # if self._debug:
-        #     self._debug(latents, f"t={t[0]} latents")
         masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
         if self._debug:
-            self._debug(masked_input, f"t={t[0]} lerped")
-        return self.forward(masked_input, t, text_embeddings)
+            self._debug(masked_input, f"t={t} lerped")
+        return masked_input
 
 
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
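For the inpainting-model path, the shape arithmetic in add_mask_channels is worth spelling out: einops.pack concatenates along the starred axis, so the standard 4 latent channels grow to 4 + 1 + 4 = 9, which is the input width dedicated inpainting UNets expect. A standalone sketch of that packing (tensor sizes are illustrative):

    # Illustrative shapes only: einops.pack concatenates along the '*' axis,
    # growing the UNet input from 4 to 4 + 1 + 4 = 9 channels.
    import einops
    import torch

    batch_size = 2
    latents = torch.randn(batch_size, 4, 64, 64)   # normal denoising latents
    mask = torch.ones(1, 1, 64, 64)                # one-channel mask
    image_latents = torch.randn(1, 4, 64, 64)      # latents of the base image

    # duplicate the single mask/image across the batch, as add_mask_channels does
    mask = einops.repeat(mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
    image_latents = einops.repeat(image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)

    model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
    assert model_input.shape == (batch_size, 9, 64, 64)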
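For the normal-model path, AddsMaskGuidance.apply_mask re-noises the original image's latents to the current timestep and blends with torch.lerp(mask_latents, latents, mask). Since lerp(start, end, weight) is start + weight * (end - start), mask == 1 keeps the freshly denoised latents and mask == 0 restores the (re-noised) base image. A tiny numeric sketch (values illustrative):

    # Illustrative values only: lerp(start, end, weight) = start + weight * (end - start).
    import torch

    denoised = torch.full((1, 4, 2, 2), 5.0)   # latents after scheduler.step
    original = torch.zeros(1, 4, 2, 2)         # re-noised latents of the base image
    mask = torch.zeros(1, 1, 2, 2)
    mask[..., 0, 0] = 1.0                      # paint only the top-left latent pixel

    blended = torch.lerp(original, denoised, mask)
    assert blended[0, 0, 0, 0] == 5.0          # where mask == 1: keep the model's output
    assert blended[0, 0, 1, 1] == 0.0          # where mask == 0: keep the base image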
@@ -263,10 +301,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                          run_id: str = None,
                          extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                          timesteps = None,
+                         additional_guidance: List[Callable] = None,
                          **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
+
+        if additional_guidance is None:
+            additional_guidance = []
+
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
                                                                  step_count=len(self.scheduler.timesteps))
@@ -289,7 +331,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
                                     text_embeddings, unconditioned_embeddings,
-                                    i, **extra_step_kwargs)
+                                    i, additional_guidance=additional_guidance,
+                                    **extra_step_kwargs)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
@@ -306,11 +349,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
              text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
-             step_index:int | None = None,
+             step_index:int | None = None, additional_guidance: List[Callable] = None,
              **extra_step_kwargs):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
 
+        if additional_guidance is None:
+            additional_guidance = []
+
         # TODO: should this scaling happen here or inside self._unet_forward?
         #     i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)
@@ -323,7 +369,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                                                step_index=step_index)
 
         # compute the previous noisy sample x_t -> x_t-1
-        return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+        step_output = self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+
+        # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffuserComponent.
+        #    But the way things are now, scheduler runs _after_ that, so there was
+        #    no way to use it to apply an operation that happens after the last scheduler.step.
+        for guidance in additional_guidance:
+            step_output = guidance(step_output, timestep, (unconditioned_embeddings, text_embeddings))
+
+        return step_output
 
     def _unet_forward(self, latents, t, text_embeddings):
         # predict the noise residual
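Each entry in additional_guidance is invoked as guidance(step_output, timestep, (unconditioned_embeddings, text_embeddings)) and must hand back a step-output-like object; AddsMaskGuidance is the only implementation in this commit, but any callable with that shape plugs in. A hypothetical pass-through example (print_stats_guidance is illustrative, not from the commit):

    # Hypothetical guidance callable, not part of this commit: it receives the
    # scheduler's step output and must return an object of the same kind.
    def print_stats_guidance(step_output, t, conditioning):
        prev_sample = step_output.prev_sample
        print(f"t={int(t)}: prev_sample mean={prev_sample.mean().item():.4f}")
        return step_output  # pass through unchanged

Appending it to the list passed as additional_guidance would log every step without altering the result.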
@@ -401,6 +455,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
 
+        assert img2img_pipeline.scheduler is self.scheduler
+
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
@@ -410,13 +466,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
             .to(device=device, dtype=latents_dtype)
 
+        guidance: List[Callable] = []
+
         if is_inpainting_model(self.unet):
             # TODO: we should probably pass this in so we don't have to try/finally around setting it.
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            self.invokeai_diffuser.model_forward_callback = \
-                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
-                                 self.scheduler, noise_func)  # self.debug_latents)
+            guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
+
 
         result = None
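So the two model families are handled differently: inpainting checkpoints see the mask as extra UNet input channels via model_forward_callback = AddsMaskLatents(...), while ordinary checkpoints have the mask re-applied after every scheduler step via the new guidance list. One plausible way is_inpainting_model could distinguish them, assuming a diffusers UNet2DConditionModel (the helper's actual body is not shown in this diff):

    # Assumption, not shown in this diff: an inpainting UNet is one whose first
    # convolution expects the 9-channel input assembled by AddsMaskLatents.
    def is_inpainting_model(unet) -> bool:
        return unet.conv_in.in_channels == 9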
@@ -425,7 +482,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 latents, text_embeddings, unconditioned_embeddings, guidance_scale,
                 extra_conditioning_info=extra_conditioning_info,
                 timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
+                run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
             if callback is not None and isinstance(result, PipelineIntermediateState):
                 callback(result)
             if result is None: