Clean-up code a bit

This commit is contained in:
Sergey Borisov 2023-08-13 19:50:48 +03:00
parent 59ba9fc0f6
commit 7a8f14d595
2 changed files with 10 additions and 11 deletions

View File

@ -34,6 +34,7 @@ from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
BasicConditioningInfo,
)
from ..util import normalize_device
@ -92,8 +93,7 @@ class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
scheduler: SchedulerMixin
noise: Optional[torch.Tensor] = None
_debug: Optional[Callable] = None
noise: torch.Tensor
def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.
@ -123,14 +123,13 @@ class AddsMaskGuidance:
# some schedulers expect t to be one-dimensional.
# TODO: file diffusers bug about inconsistency?
t = einops.repeat(t, "-> batch", batch=batch_size)
if self.noise is not None:
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
# get very confused about what is happening from step to step when we do that.
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self._debug:
self._debug(masked_input, f"t={t} lerped")
return masked_input
@ -202,8 +201,8 @@ class ControlNetData:
@dataclass
class ConditioningData:
unconditioned_embeddings: Any # TODO: type
text_embeddings: Any # TODO: type
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).

View File

@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo