mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify handling of AddsMaskGuidance, and fix some related type errors.
This commit is contained in:
parent
ffc28176fe
commit
82a69a54ac
@ -78,8 +78,8 @@ def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AddsMaskGuidance:
|
class AddsMaskGuidance:
|
||||||
mask: torch.FloatTensor
|
mask: torch.Tensor
|
||||||
mask_latents: torch.FloatTensor
|
mask_latents: torch.Tensor
|
||||||
scheduler: SchedulerMixin
|
scheduler: SchedulerMixin
|
||||||
noise: torch.Tensor
|
noise: torch.Tensor
|
||||||
gradient_mask: bool
|
gradient_mask: bool
|
||||||
@ -87,7 +87,7 @@ class AddsMaskGuidance:
|
|||||||
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||||
return self.apply_mask(latents, t)
|
return self.apply_mask(latents, t)
|
||||||
|
|
||||||
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size = latents.size(0)
|
batch_size = latents.size(0)
|
||||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
if t.dim() == 0:
|
if t.dim() == 0:
|
||||||
@ -285,11 +285,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
conditioning_data: TextConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
*,
|
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
init_timestep: torch.Tensor,
|
init_timestep: torch.Tensor,
|
||||||
additional_guidance: List[Callable] = None,
|
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
@ -302,9 +300,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if init_timestep.shape[0] == 0:
|
if init_timestep.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
if additional_guidance is None:
|
|
||||||
additional_guidance = []
|
|
||||||
|
|
||||||
orig_latents = latents.clone()
|
orig_latents = latents.clone()
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
@ -314,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
|
mask_guidance: AddsMaskGuidance | None = None
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if is_inpainting_model(self.unet):
|
if is_inpainting_model(self.unet):
|
||||||
if masked_latents is None:
|
if masked_latents is None:
|
||||||
@ -332,7 +328,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||||
|
|
||||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
mask_guidance = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
latents = self.generate_latents_from_embeddings(
|
latents = self.generate_latents_from_embeddings(
|
||||||
@ -340,7 +336,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
mask_guidance=mask_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
@ -367,16 +363,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data: TextConditioningData,
|
conditioning_data: TextConditioningData,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
*,
|
mask_guidance: AddsMaskGuidance | None = None,
|
||||||
additional_guidance: List[Callable] = None,
|
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if additional_guidance is None:
|
|
||||||
additional_guidance = []
|
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
|
|
||||||
@ -412,7 +405,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# print("timesteps:", timesteps)
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t = t.expand(batch_size)
|
batched_t = t.expand(batch_size)
|
||||||
step_output = self.step(
|
step_output = self.step(
|
||||||
@ -422,7 +414,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
mask_guidance=mask_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
t2i_adapter_data=t2i_adapter_data,
|
t2i_adapter_data=t2i_adapter_data,
|
||||||
@ -453,19 +445,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
additional_guidance: List[Callable] = None,
|
mask_guidance: AddsMaskGuidance | None = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
if additional_guidance is None:
|
|
||||||
additional_guidance = []
|
|
||||||
|
|
||||||
# one day we will expand this extension point, but for now it just does denoise masking
|
if mask_guidance is not None:
|
||||||
for guidance in additional_guidance:
|
latents = mask_guidance(latents, timestep)
|
||||||
latents = guidance(latents, timestep)
|
|
||||||
|
|
||||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
@ -541,17 +530,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||||
|
|
||||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
|
||||||
for guidance in additional_guidance:
|
# again.
|
||||||
# apply the mask to any "denoised" or "pred_original_sample" fields
|
if mask_guidance is not None:
|
||||||
|
# Apply the mask to any "denoised" or "pred_original_sample" fields.
|
||||||
if hasattr(step_output, "denoised"):
|
if hasattr(step_output, "denoised"):
|
||||||
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||||
elif hasattr(step_output, "pred_original_sample"):
|
elif hasattr(step_output, "pred_original_sample"):
|
||||||
step_output.pred_original_sample = guidance(
|
step_output.pred_original_sample = mask_guidance(
|
||||||
step_output.pred_original_sample, self.scheduler.timesteps[-1]
|
step_output.pred_original_sample, self.scheduler.timesteps[-1]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
|
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
|
||||||
|
|
||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user