Mirror of https://github.com/invoke-ai/InvokeAI
(minor) Use SilenceWarnings as a decorator rather than a context manager to save an indentation level.
parent 8e47e005a7
commit 79ceac2f82
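The decorator form only works if SilenceWarnings can be used both as a context manager and as a decorator. Below is a minimal sketch of how such a dual-use helper can be written on top of contextlib.ContextDecorator; it is an illustration of the pattern under that assumption, not InvokeAI's actual SilenceWarnings implementation, and the names SilenceWarningsSketch, old_style, and new_style are hypothetical.

import warnings
from contextlib import ContextDecorator


class SilenceWarningsSketch(ContextDecorator):
    # Hypothetical stand-in for InvokeAI's SilenceWarnings helper.
    # Subclassing ContextDecorator gives decorator behaviour for free:
    # @SilenceWarningsSketch() wraps the function body in the same
    # __enter__/__exit__ pair that `with SilenceWarningsSketch():` uses.

    def __enter__(self) -> "SilenceWarningsSketch":
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, *exc_info: object) -> None:
        self._catcher.__exit__(*exc_info)


# Old style, as in the diff below: the body needs an extra indentation level.
def old_style() -> None:
    with SilenceWarningsSketch():
        warnings.warn("suppressed")


# New style: same effect, one indentation level saved.
@SilenceWarningsSketch()
def new_style() -> None:
    warnings.warn("suppressed")

With this shape, decorating invoke() silences warnings for the whole call, which is why the body in the diff below can drop the with block and lose one level of indentation.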
@@ -657,155 +657,155 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return 1 - mask, masked_latents, self.denoise_mask.gradient

     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
-                noise = context.tensors.load(self.noise.latents_name)
-                seed = self.noise.seed
-
-            if self.latents is not None:
-                latents = context.tensors.load(self.latents.latents_name)
-                if seed is None:
-                    seed = self.latents.seed
-
-                if noise is not None and noise.shape[1:] != latents.shape[1:]:
-                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
-            elif noise is not None:
-                latents = torch.zeros_like(noise)
-            else:
-                raise Exception("'latents' or 'noise' must be provided!")
-
-            if seed is None:
-                seed = 0
-
-            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
-
-            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-            # below. Investigate whether this is appropriate.
-            t2i_adapter_data = self.run_t2i_adapters(
-                context,
-                self.t2i_adapter,
-                latents.shape,
-                do_classifier_free_guidance=True,
-            )
-
-            ip_adapters: List[IPAdapterField] = []
-            if self.ip_adapter is not None:
-                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-                if isinstance(self.ip_adapter, list):
-                    ip_adapters = self.ip_adapter
-                else:
-                    ip_adapters = [self.ip_adapter]
-
-            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-            # a series of image conditioning embeddings. This is being done here rather than in the
-            # big model context below in order to use less VRAM on low-VRAM systems.
-            # The image prompts are then passed to prep_ip_adapter_data().
-            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
-
-            # get the unet's config so that we can pass the base to dispatch_progress()
-            unet_config = context.models.get_config(self.unet.unet.key)
-
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
-
-            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-                for lora in self.unet.loras:
-                    lora_info = context.models.load(lora.lora)
-                    assert isinstance(lora_info.model, LoRAModelRaw)
-                    yield (lora_info.model, lora.weight)
-                    del lora_info
-                return
-
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-                set_seamless(unet, self.unet.seamless_axes),  # FIXME
-                # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_unet(
-                    unet,
-                    loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
-                ),
-            ):
-                assert isinstance(unet, UNet2DConditionModel)
-                latents = latents.to(device=unet.device, dtype=unet.dtype)
-                if noise is not None:
-                    noise = noise.to(device=unet.device, dtype=unet.dtype)
-                if mask is not None:
-                    mask = mask.to(device=unet.device, dtype=unet.dtype)
-                if masked_latents is not None:
-                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
-                scheduler = get_scheduler(
-                    context=context,
-                    scheduler_info=self.unet.scheduler,
-                    scheduler_name=self.scheduler,
-                    seed=seed,
-                )
-
-                pipeline = self.create_pipeline(unet, scheduler)
-
-                _, _, latent_height, latent_width = latents.shape
-                conditioning_data = self.get_conditioning_data(
-                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
-                )
-
-                controlnet_data = self.prep_control_data(
-                    context=context,
-                    control_input=self.control,
-                    latents_shape=latents.shape,
-                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                    do_classifier_free_guidance=True,
-                    exit_stack=exit_stack,
-                )
-
-                ip_adapter_data = self.prep_ip_adapter_data(
-                    context=context,
-                    ip_adapters=ip_adapters,
-                    image_prompts=image_prompts,
-                    exit_stack=exit_stack,
-                    latent_height=latent_height,
-                    latent_width=latent_width,
-                    dtype=unet.dtype,
-                )
-
-                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                    scheduler,
-                    device=unet.device,
-                    steps=self.steps,
-                    denoising_start=self.denoising_start,
-                    denoising_end=self.denoising_end,
-                    seed=seed,
-                )
-
-                result_latents = pipeline.latents_from_embeddings(
-                    latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    mask=mask,
-                    masked_latents=masked_latents,
-                    gradient_mask=gradient_mask,
-                    num_inference_steps=num_inference_steps,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    control_data=controlnet_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                    callback=step_callback,
-                )
-
-            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-            result_latents = result_latents.to("cpu")
-            TorchDevice.empty_cache()
-
-            name = context.tensors.save(tensor=result_latents)
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
+
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
+            if seed is None:
+                seed = self.latents.seed
+
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
+
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")
+
+        if seed is None:
+            seed = 0
+
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )
+
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )
+
+            pipeline = self.create_pipeline(unet, scheduler)
+
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
+
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
+
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)