inpainting for the normal model. I think it works this time.
commit 5c7e6751e0
parent 69d42762de
@@ -22,8 +22,9 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils.outputs import BaseOutput
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
@@ -53,39 +54,76 @@ _default_personalization_config_params = dict(
 
 @dataclass
 class AddsMaskLatents:
+    """Add the channels required for inpainting model input.
+
+    The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
+    and the latent encoding of the base image.
+
+    This class assumes the same mask and base image should apply to all items in the batch.
+    """
     forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
     mask: torch.FloatTensor
-    mask_latents: torch.FloatTensor
+    initial_image_latents: torch.FloatTensor
 
     def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
-        batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w')
+        model_input = self.add_mask_channels(latents)
         return self.forward(model_input, t, text_embeddings)
 
+    def add_mask_channels(self, latents):
+        batch_size = latents.size(0)
+        # duplicate mask and latents for each batch
+        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
+        # add mask and image as additional channels
+        model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
+        return model_input
+
+
+def are_like_tensors(a: torch.Tensor, b: object) -> bool:
+    return (
+        isinstance(b, torch.Tensor)
+        and (a.size() == b.size())
+    )
+
 
 @dataclass
 class AddsMaskGuidance:
-    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
     mask: torch.FloatTensor
     mask_latents: torch.FloatTensor
     _scheduler: SchedulerMixin
     _noise_func: Callable
     _debug: Optional[Callable] = None
 
-    def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor:
+    def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
+        output_class = step_output.__class__  # We'll create a new one with masked data.
+
+        # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
+        # It's reasonable to assume the first thing is prev_sample, but then does it have other things
+        # like pred_original_sample? Should we apply the mask to them too?
+        # But what if there's just some other random field?
+        prev_sample = step_output[0]
+        # Mask anything that has the same shape as prev_sample, return others as-is.
+        return output_class(
+            {k: (self.apply_mask(v, self._t_for_field(k, t))
+                 if are_like_tensors(prev_sample, v) else v)
+             for k, v in step_output.items()}
+        )
+
+    def _t_for_field(self, field_name:str, t):
+        if field_name == "pred_original_sample":
+            return torch.zeros_like(t, dtype=t.dtype)  # it represents t=0
+        return t
+
+    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
         batch_size = latents.size(0)
         mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
         noise = self._noise_func(self.mask_latents)
-        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0])  # .to(dtype=mask_latents.dtype)
+        mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t)
+        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
         mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
-        # if self._debug:
-        #     self._debug(latents, f"t={t[0]} latents")
         masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
         if self._debug:
-            self._debug(masked_input, f"t={t[0]} lerped")
-        return self.forward(masked_input, t, text_embeddings)
+            self._debug(masked_input, f"t={t} lerped")
+        return masked_input
 
 
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
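As the new docstring says, the inpainting UNet expects the usual latent channels plus a one-channel mask and the latent encoding of the base image. Below is a minimal, standalone sketch of what add_mask_channels assembles, using toy tensors in place of real VAE latents; the 4-channel latent size matches Stable Diffusion's VAE, but the values and 64x64 resolution are illustrative only.

# Minimal sketch of AddsMaskLatents.add_mask_channels with toy tensors.
# Shapes follow Stable Diffusion's 4-channel latent space; values are illustrative.
import einops
import torch

batch, height, width = 2, 64, 64

latents = torch.randn(batch, 4, height, width)        # noisy latents being denoised
mask = torch.ones(1, 1, height, width)                # one-channel inpainting mask
image_latents = torch.randn(1, 4, height, width)      # VAE encoding of the base image

# The same mask and base image apply to every item in the batch.
mask = einops.repeat(mask, 'b c h w -> (repeat b) c h w', repeat=batch)
image_latents = einops.repeat(image_latents, 'b c h w -> (repeat b) c h w', repeat=batch)

# Pack along the channel axis: 4 + 1 + 4 = 9 channels, the inpainting UNet's input size.
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
print(model_input.shape)  # torch.Size([2, 9, 64, 64])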
@@ -263,10 +301,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                           run_id: str = None,
                                           extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                                           timesteps = None,
+                                          additional_guidance: List[Callable] = None,
                                           **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
 
+        if additional_guidance is None:
+            additional_guidance = []
+
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
                                                                  step_count=len(self.scheduler.timesteps))
@@ -289,7 +331,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
                                     text_embeddings, unconditioned_embeddings,
-                                    i, **extra_step_kwargs)
+                                    i, additional_guidance=additional_guidance,
+                                    **extra_step_kwargs)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
@@ -306,11 +349,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
              text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
-             step_index:int | None = None,
+             step_index:int | None = None, additional_guidance: List[Callable] = None,
              **extra_step_kwargs):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
 
+        if additional_guidance is None:
+            additional_guidance = []
+
         # TODO: should this scaling happen here or inside self._unet_forward?
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)
@@ -323,7 +369,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                                     step_index=step_index)
 
         # compute the previous noisy sample x_t -> x_t-1
-        return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+        step_output = self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs)
+
+        # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
+        #   But the way things are now, scheduler runs _after_ that, so there was
+        #   no way to use it to apply an operation that happens after the last scheduler.step.
+        for guidance in additional_guidance:
+            step_output = guidance(step_output, timestep, (unconditioned_embeddings, text_embeddings))
+
+        return step_output
 
     def _unet_forward(self, latents, t, text_embeddings):
         # predict the noise residual
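The comments above spell out the design question: additional_guidance callables run after scheduler.step, so each one receives the scheduler's step output rather than the raw model output. The hypothetical no-op hook below only illustrates the contract implied by the loop (the function name and behaviour are illustrative, not part of this diff); AddsMaskGuidance is the one real implementation this commit adds.

# Hypothetical example of the additional_guidance contract implied by the loop above:
# each callable gets the scheduler's step output, the scalar timestep, and a
# (unconditioned, conditioned) embeddings tuple, and must return a step-output-like object.
from typing import Any, Tuple
import torch

def inspect_step(step_output: Any, t: torch.Tensor, conditioning: Tuple[torch.Tensor, torch.Tensor]) -> Any:
    """A do-nothing guidance hook that only inspects the partially denoised latents."""
    prev_sample = step_output.prev_sample
    print(f"t={int(t)}  |prev_sample| = {prev_sample.norm().item():.3f}")
    return step_output

# Would be threaded through the pipeline as: additional_guidance=[inspect_step]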
@@ -401,6 +455,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
+
+        assert img2img_pipeline.scheduler is self.scheduler
 
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
@@ -410,13 +466,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \
             .to(device=device, dtype=latents_dtype)
 
+        guidance: List[Callable] = []
+
         if is_inpainting_model(self.unet):
+            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
             self.invokeai_diffuser.model_forward_callback = \
                 AddsMaskLatents(self._unet_forward, mask, init_image_latents)
         else:
-            self.invokeai_diffuser.model_forward_callback = \
-                AddsMaskGuidance(self._unet_forward, mask, init_image_latents,
-                                 self.scheduler, noise_func) # self.debug_latents)
+            guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
 
         result = None
 
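For the normal (non-inpainting) model, nothing changes about the UNet input; instead AddsMaskGuidance is appended to the guidance list so that, after every scheduler step, the region outside the mask is overwritten with the original image's latents re-noised to the current timestep. Below is a standalone sketch of that blend, with plain Gaussian noise standing in for scheduler.add_noise and toy tensors in place of real VAE latents.

# Standalone sketch of the blend AddsMaskGuidance.apply_mask performs for the normal model.
# Plain Gaussian noise stands in for scheduler.add_noise; shapes and values are illustrative.
import torch

batch, channels, height, width = 1, 4, 64, 64

latents = torch.randn(batch, channels, height, width)       # prev_sample from scheduler.step
init_latents = torch.randn(batch, channels, height, width)  # VAE encoding of the original image
mask = torch.zeros(batch, 1, height, width)
mask[..., 16:48, 16:48] = 1.0                                # 1 = region the model may repaint

# Re-noise the original latents so they match the noise level of `latents` at this timestep.
noised_init = init_latents + 0.1 * torch.randn_like(init_latents)

# torch.lerp(a, b, w) = a + w * (b - a): where mask == 1 keep the generated latents,
# where mask == 0 keep the (re-noised) original-image latents.
masked_input = torch.lerp(noised_init, latents, mask.expand_as(latents))
print(masked_input.shape)  # torch.Size([1, 4, 64, 64])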
@@ -425,7 +482,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     latents, text_embeddings, unconditioned_embeddings, guidance_scale,
                     extra_conditioning_info=extra_conditioning_info,
                     timesteps=timesteps,
-                    run_id=run_id, **extra_step_kwargs):
+                    run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
                 if callback is not None and isinstance(result, PipelineIntermediateState):
                     callback(result)
             if result is None: