SDXL Prompt and t2l nodes draft, add fp32 to vae decode

This commit is contained in:
Sergey Borisov 2023-07-11 18:19:36 +03:00
parent 34cff848c7
commit 358ced6bab
3 changed files with 537 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union, List
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
import torch
@ -14,6 +14,7 @@ from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
@ -23,6 +24,33 @@ class ConditioningField(BaseModel):
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
#type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
ConditioningInfoType = Annotated[
Union[BasicConditioningInfo, SDXLConditioningInfo],
Field(discriminator="type")
]
@dataclass
class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
#unconditioned: Optional[torch.Tensor]
#class ConditioningAlgo(str, Enum):
# Compose = "compose"
# ComposeEx = "compose_ex"
# PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
@ -121,8 +149,9 @@ class CompelInvocation(BaseInvocation):
cross_attention_control_args=options.get(
"cross_attention_control", None),)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
raise NotImplementedError("TODO: redo to new conditionings")
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
@ -132,6 +161,252 @@ class CompelInvocation(BaseInvocation):
),
)
# TODO: implement with compel package update
class SDXLCompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel"] = "sdxl_compel"
prompt: str = Field(default="", description="Prompt")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
def run_clip(self, context, clip_field):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
### TODO: pooled
text_inputs = tokenizer(
self.prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
c_pooled = prompt_embeds[0]
c = prompt_embeds.hidden_states[-2]
### TODO: pooled
# TODO: long prompt support
# if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
del compel
return c.detach(), c_pooled.detach(), None
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip(context, self.clip1)
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2)
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLRawPromptInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
prompt: str = Field(default="", description="Prompt")
style: str = Field(default="", description="Style prompt")
clip1: ClipField = Field(None, description="Clip to use")
clip2: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
def run_clip(self, context, clip_field, prompt):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except ModelNotFoundException:
# print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
c_pooled = prompt_embeds[0]
c = prompt_embeds.hidden_states[-2]
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
return c.detach(), c_pooled.detach(), None
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip(context, self.clip1, self.prompt)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.prompt)
else:
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
extra_conditioning=ec1,
)
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"

View File

@ -28,6 +28,13 @@ from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@ -449,6 +456,7 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(False, description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
@ -467,6 +475,31 @@ class LatentsToImageInvocation(BaseInvocation):
)
with vae_info as vae:
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else:

View File

@ -0,0 +1,227 @@
import copy
import torch
import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_sdxl"] = "t2l_sdxl"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.noise.latents_name)
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
prompt_embeds = positive_cond_data.conditionings[0].embeds
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
add_time_ids = torch.tensor([(latents.shape[2] * 8, latents.shape[3] * 8) + (0, 0) + (latents.shape[2] * 8, latents.shape[3] * 8)])
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
scheduler.set_timesteps(self.steps)
timesteps = scheduler.timesteps
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
#################
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
if not context.services.configuration.sequential_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_uncond
#del noise_pred_text
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
with tqdm(total=self.steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_uncond = unet(
latent_model_input,
t,
encoder_hidden_states=negative_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
noise_pred_text = unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)