Remove sdxl t2l/l2l nodes

This commit is contained in:
Sergey Borisov 2023-08-08 03:38:42 +03:00
parent 1db2c93f75
commit 492bfe002a

View File

@ -1,17 +1,10 @@
import torch import torch
import inspect from typing import Literal
from tqdm import tqdm from pydantic import Field
from typing import List, Literal, Optional, Union
from pydantic import Field, validator from ...backend.model_management import ModelType, SubModelType
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput): class SDXLModelLoaderOutput(BaseInvocationOutput):
@ -201,526 +194,3 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
), ),
), ),
) )
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_sdxl"] = "t2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Text To Latents",
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
}
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
num_inference_steps = self.steps
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
latents = latents.to(device=unet.device, dtype=unet.dtype) * scheduler.init_noise_sigma
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# del noise_pred_text
# if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_text
# del noise_pred_uncond
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# del noise_pred
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
class SDXLLatentsToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["l2l_sdxl"] = "l2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Latents to Latents",
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
}
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
sample,
step,
total_steps,
) -> None:
stable_diffusion_xl_step_callback(
context=context,
node=self.dict(),
source_node_id=source_node_id,
sample=sample,
step=step,
total_steps=total_steps,
)
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
do_classifier_free_guidance = True
cross_attention_kwargs = None
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet.device)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
# apply scheduler extra args
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
generator=torch.Generator(device=unet.device).manual_seed(0),
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# apply denoising_end
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
num_inference_steps = num_inference_steps - skipped_final_steps
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# del noise_pred_text
# if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_text
# del noise_pred_uncond
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# del noise_pred
# import gc
# gc.collect()
# torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)