mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply denoising_start/end, add torch-sdp to memory effictiend attention func
This commit is contained in:
parent
b0738b7f70
commit
2539e26c18
@ -320,8 +320,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
||||||
|
|
||||||
print(f"{c1.shape=} {c2.shape=} {c2_pooled.shape=} {self.prompt=}")
|
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
target_size = (self.target_height, self.target_width)
|
target_size = (self.target_height, self.target_width)
|
||||||
|
@ -122,6 +122,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
|
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -310,6 +311,25 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
return control_data
|
return control_data
|
||||||
|
|
||||||
|
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||||
|
# apply denoising_start
|
||||||
|
num_inference_steps = steps
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
|
||||||
|
t_start = int(round(denoising_start * num_inference_steps))
|
||||||
|
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||||
|
num_inference_steps = num_inference_steps - t_start
|
||||||
|
|
||||||
|
# apply denoising_end
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||||
|
|
||||||
|
skipped_final_steps = int(round((1 - denoising_end) * steps))
|
||||||
|
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||||
|
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||||
|
|
||||||
|
return num_inference_steps, timesteps
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
@ -360,11 +380,20 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_inference_steps, timesteps = self.init_scheduler(
|
||||||
|
scheduler,
|
||||||
|
device=unet.device,
|
||||||
|
steps=self.steps,
|
||||||
|
denoising_start=0.0,
|
||||||
|
denoising_end=self.denoising_end,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||||
noise=noise,
|
noise=noise,
|
||||||
num_inference_steps=self.steps,
|
timesteps=timesteps,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=control_data, # list[ControlNetData]
|
control_data=control_data, # list[ControlNetData]
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
@ -385,8 +414,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
|
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
|
||||||
|
|
||||||
|
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||||
|
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||||
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -405,6 +438,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||||
|
noise = None
|
||||||
|
if self.noise is not None:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
@ -432,6 +467,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||||
unet_info.context.model, _lora_loader()
|
unet_info.context.model, _lora_loader()
|
||||||
), unet_info as unet:
|
), unet_info as unet:
|
||||||
|
if noise is not None:
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
@ -448,28 +484,30 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
model=pipeline,
|
model=pipeline,
|
||||||
context=context,
|
context=context,
|
||||||
control_input=self.control,
|
control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=latent.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
initial_latents = (
|
initial_latents = latent
|
||||||
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
if self.denoising_start <= 0.0:
|
||||||
)
|
initial_latents = torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
||||||
|
|
||||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
num_inference_steps, timesteps = self.init_scheduler(
|
||||||
self.steps,
|
scheduler,
|
||||||
self.strength,
|
|
||||||
device=unet.device,
|
device=unet.device,
|
||||||
|
steps=self.steps,
|
||||||
|
denoising_start=self.denoising_start,
|
||||||
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=initial_latents,
|
latents=initial_latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
num_inference_steps=self.steps,
|
num_inference_steps=num_inference_steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=control_data, # list[ControlNetData]
|
control_data=control_data, # list[ControlNetData]
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
@ -340,9 +340,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if xformers is available, use it, otherwise use sliced attention.
|
if xformers is available, use it, otherwise use sliced attention.
|
||||||
"""
|
"""
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
|
if self.unet.device.type == "cuda":
|
||||||
|
if is_xformers_available() and not config.disable_xformers:
|
||||||
self.enable_xformers_memory_efficient_attention()
|
self.enable_xformers_memory_efficient_attention()
|
||||||
else:
|
return
|
||||||
|
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
|
# diffusers enable sdp automatically
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
if self.device.type == "cpu" or self.device.type == "mps":
|
if self.device.type == "cpu" or self.device.type == "mps":
|
||||||
mem_free = psutil.virtual_memory().free
|
mem_free = psutil.virtual_memory().free
|
||||||
elif self.device.type == "cuda":
|
elif self.device.type == "cuda":
|
||||||
@ -398,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
*,
|
*,
|
||||||
noise: torch.Tensor,
|
noise: Optional[torch.Tensor],
|
||||||
timesteps=None,
|
timesteps=None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
@ -434,7 +440,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
*,
|
*,
|
||||||
noise: torch.Tensor,
|
noise: Optional[torch.Tensor],
|
||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
@ -457,6 +463,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
dtype=timesteps.dtype,
|
dtype=timesteps.dtype,
|
||||||
device=self._model_group.device_for(self.unet),
|
device=self._model_group.device_for(self.unet),
|
||||||
)
|
)
|
||||||
|
if noise is not None:
|
||||||
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user