mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply denoising_start/end, add torch-sdp to memory efficient attention func
This commit is contained in:
parent
b0738b7f70
commit
2539e26c18
@ -320,8 +320,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
||||
|
||||
print(f"{c1.shape=} {c2.shape=} {c2_pooled.shape=} {self.prompt=}")
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
@ -122,6 +122,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
@ -310,6 +311,25 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
    """Set up ``scheduler`` for ``steps`` steps and trim its timesteps to a window.

    The scheduler is configured via ``scheduler.set_timesteps(steps, device=device)``
    and the resulting ``scheduler.timesteps`` sequence is then cut at both ends:
    ``denoising_start`` removes a leading fraction of the steps and
    ``denoising_end`` removes a trailing fraction. Step counts are multiplied by
    ``scheduler.order`` when converted to indices into ``scheduler.timesteps``.

    Args:
        scheduler: Object exposing ``set_timesteps(n, device=...)``, ``timesteps``
            and ``order``.
        device: Passed through unchanged to ``scheduler.set_timesteps``.
        steps: Total number of inference steps for the full schedule.
        denoising_start: Fraction of the schedule to skip at the beginning.
        denoising_end: Fraction of the schedule at which to stop.

    Returns:
        A ``(num_inference_steps, timesteps)`` tuple: the number of steps left
        inside the window and the corresponding slice of ``scheduler.timesteps``.
        If the window is empty (``denoising_end`` cuts off at least as many steps
        as remain after ``denoising_start``), the step count is ``0`` and only the
        warm-up portion of the timesteps (normally nothing) is returned.
    """
    scheduler.set_timesteps(steps, device=device)

    # apply denoising_start: drop the leading part of the schedule
    t_start = int(round(denoising_start * steps))
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    num_inference_steps = steps - t_start

    # apply denoising_end: drop the trailing part of the schedule
    num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
    skipped_final_steps = int(round((1 - denoising_end) * steps))

    # Clamp at zero. Without this, a window where denoising_end falls before
    # denoising_start yields a negative step count, and the slice below would
    # interpret the negative stop index as "count from the end", returning a
    # non-empty, wrong set of timesteps.
    num_inference_steps = max(num_inference_steps - skipped_final_steps, 0)
    timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]

    return num_inference_steps, timesteps
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
with SilenceWarnings():
|
||||
@ -360,11 +380,20 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=0.0,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
timesteps=timesteps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
@ -385,8 +414,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
|
||||
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@ -405,6 +438,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
@ -432,6 +467,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), unet_info as unet:
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
@ -448,28 +484,30 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
model=pipeline,
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
latents_shape=latent.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = (
|
||||
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
||||
)
|
||||
initial_latents = latent
|
||||
if self.denoising_start <= 0.0:
|
||||
initial_latents = torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
num_inference_steps, timesteps = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
|
@ -340,9 +340,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
|
||||
if self.unet.device.type == "cuda":
|
||||
if is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
return
|
||||
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
# diffusers enable sdp automatically
|
||||
return
|
||||
|
||||
|
||||
if self.device.type == "cpu" or self.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == "cuda":
|
||||
@ -398,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps=None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
run_id=None,
|
||||
@ -434,7 +440,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: torch.Tensor,
|
||||
noise: Optional[torch.Tensor],
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
@ -457,6 +463,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
dtype=timesteps.dtype,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
if noise is not None:
|
||||
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user