Apply denoising_start/end, add torch-sdp to memory effictiend attention func

This commit is contained in:
Sergey Borisov 2023-08-07 19:57:11 +03:00
parent b0738b7f70
commit 2539e26c18
3 changed files with 88 additions and 45 deletions

View File

@ -320,8 +320,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
else: else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True) c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
print(f"{c1.shape=} {c2.shape=} {c2_pooled.shape=} {self.prompt=}")
original_size = (self.original_height, self.original_width) original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left) crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width) target_size = (self.target_height, self.target_width)

View File

@ -122,6 +122,7 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on # fmt: on
@ -310,6 +311,25 @@ class TextToLatentsInvocation(BaseInvocation):
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
# apply denoising_start
num_inference_steps = steps
scheduler.set_timesteps(num_inference_steps, device=device)
t_start = int(round(denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start
# apply denoising_end
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
skipped_final_steps = int(round((1 - denoising_end) * steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
return num_inference_steps, timesteps
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): with SilenceWarnings():
@ -360,11 +380,20 @@ class TextToLatentsInvocation(BaseInvocation):
exit_stack=exit_stack, exit_stack=exit_stack,
) )
num_inference_steps, timesteps = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=0.0,
denoising_end=self.denoising_end,
)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise, noise=noise,
num_inference_steps=self.steps, timesteps=timesteps,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData] control_data=control_data, # list[ControlNetData]
callback=step_callback, callback=step_callback,
@ -385,8 +414,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l" type: Literal["l2l"] = "l2l"
# Inputs # Inputs
noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)")
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
#denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
latents: Optional[LatentsField] = Field(description="The latents to use as a base image") latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -405,7 +438,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers with SilenceWarnings(): # this quenches NSFW nag from diffusers
noise = context.services.latents.get(self.noise.latents_name) noise = None
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name) latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
@ -432,7 +467,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader() unet_info.context.model, _lora_loader()
), unet_info as unet: ), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype) if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype) latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( scheduler = get_scheduler(
@ -448,28 +484,30 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
model=pipeline, model=pipeline,
context=context, context=context,
control_input=self.control, control_input=self.control,
latents_shape=noise.shape, latents_shape=latent.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack, exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = ( initial_latents = latent
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) if self.denoising_start <= 0.0:
) initial_latents = torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
timesteps, _ = pipeline.get_img2img_timesteps( num_inference_steps, timesteps = self.init_scheduler(
self.steps, scheduler,
self.strength,
device=unet.device, device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
) )
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents, latents=initial_latents,
timesteps=timesteps, timesteps=timesteps,
noise=noise, noise=noise,
num_inference_steps=self.steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData] control_data=control_data, # list[ControlNetData]
callback=step_callback, callback=step_callback,

View File

@ -340,33 +340,39 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if xformers is available, use it, otherwise use sliced attention. if xformers is available, use it, otherwise use sliced attention.
""" """
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers: if self.unet.device.type == "cuda":
self.enable_xformers_memory_efficient_attention() if is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enable sdp automatically
return
if self.device.type == "cpu" or self.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else: else:
if self.device.type == "cpu" or self.device.type == "mps": raise ValueError(f"unrecognized device {self.device}")
mem_free = psutil.virtual_memory().free # input tensor of [1, 4, h/8, w/8]
elif self.device.type == "cuda": # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device)) bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
else: max_size_required_for_baddbmm = (
raise ValueError(f"unrecognized device {self.device}") 16
# input tensor of [1, 4, h/8, w/8] * latents.size(dim=2)
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)] * latents.size(dim=3)
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 * latents.size(dim=2)
max_size_required_for_baddbmm = ( * latents.size(dim=3)
16 * bytes_per_element_needed_for_baddbmm_duplication
* latents.size(dim=2) )
* latents.size(dim=3) if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
* latents.size(dim=2) self.enable_attention_slicing(slice_size="max")
* latents.size(dim=3) elif torch.backends.mps.is_available():
* bytes_per_element_needed_for_baddbmm_duplication # diffusers recommends always enabling for mps
) self.enable_attention_slicing(slice_size="max")
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code else:
self.enable_attention_slicing(slice_size="max") self.disable_attention_slicing()
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
self.enable_attention_slicing(slice_size="max")
else:
self.disable_attention_slicing()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass. # overridden method; types match the superclass.
@ -398,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
num_inference_steps: int, num_inference_steps: int,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
*, *,
noise: torch.Tensor, noise: Optional[torch.Tensor],
timesteps=None, timesteps=None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
run_id=None, run_id=None,
@ -434,7 +440,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timesteps, timesteps,
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
*, *,
noise: torch.Tensor, noise: Optional[torch.Tensor],
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
@ -457,8 +463,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
dtype=timesteps.dtype, dtype=timesteps.dtype,
device=self._model_group.device_for(self.unet), device=self._model_group.device_for(self.unet),
) )
#latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers if noise is not None:
latents = self.scheduler.add_noise(latents, noise, batched_t) #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
yield PipelineIntermediateState( yield PipelineIntermediateState(
run_id=run_id, run_id=run_id,