mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Separate the prompts for SDXL and SDXL-refiner; add denoising start/end fields; add an l2l node (supports both SDXL and SDXL-refiner); add fp32 to VAE encode
This commit is contained in:
parent
ab840742b0
commit
c9c2229917
@ -36,6 +36,7 @@ class BasicConditioningInfo:
|
|||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
ConditioningInfoType = Annotated[
|
ConditioningInfoType = Annotated[
|
||||||
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
@ -300,6 +301,12 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
prompt: str = Field(default="", description="Prompt")
|
||||||
style: str = Field(default="", description="Style prompt")
|
style: str = Field(default="", description="Style prompt")
|
||||||
|
original_width: int = Field(1024, description="")
|
||||||
|
original_height: int = Field(1024, description="")
|
||||||
|
crop_top: int = Field(0, description="")
|
||||||
|
crop_left: int = Field(0, description="")
|
||||||
|
target_width: int = Field(1024, description="")
|
||||||
|
target_height: int = Field(1024, description="")
|
||||||
clip1: ClipField = Field(None, description="Clip to use")
|
clip1: ClipField = Field(None, description="Clip to use")
|
||||||
clip2: ClipField = Field(None, description="Clip to use")
|
clip2: ClipField = Field(None, description="Clip to use")
|
||||||
|
|
||||||
@ -385,11 +392,20 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
c2, c2_pooled, ec2 = self.run_clip(context, self.clip2, self.style)
|
||||||
|
|
||||||
|
original_size = (self.original_height, self.original_width)
|
||||||
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + target_size
|
||||||
|
])
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
SDXLConditioningInfo(
|
SDXLConditioningInfo(
|
||||||
embeds=torch.cat([c1, c2], dim=-1),
|
embeds=torch.cat([c1, c2], dim=-1),
|
||||||
pooled_embeds=c2_pooled,
|
pooled_embeds=c2_pooled,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
extra_conditioning=ec1,
|
extra_conditioning=ec1,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -404,6 +420,124 @@ class SDXLRawPromptInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class SDXLRefinerRawPromptInvocation(BaseInvocation):
    """Turn a raw style prompt into SDXL-refiner conditioning.

    The ``style`` text is tokenized and passed through the ``clip2`` text
    encoder without any prompt-syntax parsing. The resulting embeddings are
    bundled with the original-size / crop / aesthetic-score values and saved
    through the latents service under a generated conditioning name.
    """

    type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"

    style: str = Field(default="", description="Style prompt")  # TODO: ?
    original_width: int = Field(1024, description="")
    original_height: int = Field(1024, description="")
    crop_top: int = Field(0, description="")
    crop_left: int = Field(0, description="")
    aesthetic_score: float = Field(6.0, description="")
    clip2: ClipField = Field(None, description="Clip to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Prompt (Raw)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def run_clip(self, context, clip_field, prompt):
        """Encode ``prompt`` with the tokenizer/text encoder named by ``clip_field``.

        LoRAs listed on ``clip_field``, textual-inversion models referenced in
        the prompt as ``<name>`` triggers, and the clip-skip setting are all
        applied to the text encoder for the duration of the encoding.

        Returns a 3-tuple: the second-to-last hidden state, the first element
        of the encoder output (used by the caller as the pooled embedding),
        and ``None`` in the extra-conditioning slot. Both tensors are detached.
        """
        model_manager = context.services.model_manager
        tokenizer_info = model_manager.get_model(**clip_field.tokenizer.dict())
        text_encoder_info = model_manager.get_model(**clip_field.text_encoder.dict())

        def _iter_loras():
            # Lazily resolve each LoRA so a model is only fetched when the
            # patcher actually asks for it.
            for lora in clip_field.loras:
                lora_info = model_manager.get_model(**lora.dict(exclude={"weight"}))
                yield (lora_info.context.model, lora.weight)
                del lora_info

        # Collect textual-inversion models for every "<name>" trigger in the
        # prompt; unknown names are reported and otherwise ignored.
        ti_models = []
        for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
            try:
                ti_info = model_manager.get_model(
                    model_name=trigger[1:-1],
                    base_model=clip_field.text_encoder.base_model,
                    model_type=ModelType.TextualInversion,
                )
                ti_models.append(ti_info.context.model)
            except ModelNotFoundException:
                print(f"Warn: trigger: \"{trigger}\" not found")

        encoder_model = text_encoder_info.context.model
        with ModelPatcher.apply_lora_text_encoder(encoder_model, _iter_loras()),\
                ModelPatcher.apply_ti(tokenizer_info.context.model, encoder_model, ti_models) as (tokenizer, _ti_manager),\
                ModelPatcher.apply_clip_skip(encoder_model, clip_field.skipped_layers),\
                text_encoder_info as text_encoder:

            token_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

            encoder_output = text_encoder(
                token_ids.to(text_encoder.device),
                output_hidden_states=True,
            )
            pooled = encoder_output[0]
            penultimate = encoder_output.hidden_states[-2]

        # Drop the local references to the tokenizer/encoder objects before returning.
        del tokenizer, text_encoder, tokenizer_info, text_encoder_info

        return penultimate.detach(), pooled.detach(), None

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
        """Encode the style prompt, save the conditioning, and return a reference to it."""
        embeds, pooled_embeds, extra_conditioning = self.run_clip(context, self.clip2, self.style)

        # One row: (original height, original width, crop top, crop left, aesthetic score).
        add_time_ids = torch.tensor([(
            self.original_height,
            self.original_width,
            self.crop_top,
            self.crop_left,
            self.aesthetic_score,
        )])

        refiner_conditioning = SDXLConditioningInfo(
            embeds=embeds,
            pooled_embeds=pooled_embeds,
            add_time_ids=add_time_ids,
            extra_conditioning=extra_conditioning,  # or None
        )
        conditioning_data = ConditioningFieldData(conditionings=[refiner_conditioning])

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
            conditioning=ConditioningField(conditioning_name=conditioning_name),
        )
|
||||||
|
|
||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
|
@ -650,6 +650,8 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
tiled: bool = Field(
|
tiled: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Encode latents by overlaping tiles(less memory consumption)")
|
description="Encode latents by overlaping tiles(less memory consumption)")
|
||||||
|
fp32: bool = Field(False, description="Decode in full precision")
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -676,6 +678,32 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||||
|
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
|
orig_dtype = vae.dtype
|
||||||
|
if self.fp32:
|
||||||
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
use_torch_2_0_or_xformers = isinstance(
|
||||||
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
|
(
|
||||||
|
AttnProcessor2_0,
|
||||||
|
XFormersAttnProcessor,
|
||||||
|
LoRAXFormersAttnProcessor,
|
||||||
|
LoRAAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# if xformers or torch_2_0 is used attention block does not need
|
||||||
|
# to be in float32 which can save lots of memory
|
||||||
|
if use_torch_2_0_or_xformers:
|
||||||
|
vae.post_quant_conv.to(orig_dtype)
|
||||||
|
vae.decoder.conv_in.to(orig_dtype)
|
||||||
|
vae.decoder.mid_block.to(orig_dtype)
|
||||||
|
#else:
|
||||||
|
# latents = latents.float()
|
||||||
|
|
||||||
|
else:
|
||||||
|
vae.to(dtype=torch.float16)
|
||||||
|
#latents = latents.half()
|
||||||
|
|
||||||
if self.tiled:
|
if self.tiled:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
else:
|
else:
|
||||||
@ -690,6 +718,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
) # FIXME: uses torch.randn. make reproducible!
|
) # FIXME: uses torch.randn. make reproducible!
|
||||||
|
|
||||||
latents = 0.18215 * latents
|
latents = 0.18215 * latents
|
||||||
|
latents = latents.to(dtype=orig_dtype)
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
# context.services.latents.set(name, latents)
|
# context.services.latents.set(name, latents)
|
||||||
|
@ -29,6 +29,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
@ -68,12 +69,12 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||||
|
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||||
|
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||||
add_time_ids = torch.tensor([(latents.shape[2] * 8, latents.shape[3] * 8) + (0, 0) + (latents.shape[2] * 8, latents.shape[3] * 8)])
|
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
@ -81,18 +82,12 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler.set_timesteps(self.steps)
|
num_inference_steps = self.steps
|
||||||
|
scheduler.set_timesteps(num_inference_steps)
|
||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
latents = latents * scheduler.init_noise_sigma
|
latents = latents * scheduler.init_noise_sigma
|
||||||
|
|
||||||
extra_step_kwargs = dict()
|
|
||||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
|
||||||
extra_step_kwargs.update(
|
|
||||||
eta=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
#################
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict()
|
||||||
@ -101,22 +96,33 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with unet_info as unet:
|
||||||
|
|
||||||
|
extra_step_kwargs = dict()
|
||||||
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
|
extra_step_kwargs.update(
|
||||||
|
eta=0.0,
|
||||||
|
)
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
extra_step_kwargs.update(
|
extra_step_kwargs.update(
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
||||||
|
|
||||||
|
# apply denoising_end
|
||||||
|
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||||
|
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||||
|
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||||
|
|
||||||
if not context.services.configuration.sequential_guidance:
|
if not context.services.configuration.sequential_guidance:
|
||||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||||
|
|
||||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
|
||||||
with tqdm(total=self.steps) as progress_bar:
|
with tqdm(total=self.steps) as progress_bar:
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
@ -157,13 +163,249 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
|
||||||
with tqdm(total=self.steps) as progress_bar:
|
with tqdm(total=self.steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
|
||||||
|
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||||
|
|
||||||
|
#import gc
|
||||||
|
#gc.collect()
|
||||||
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
|
||||||
|
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
|
||||||
|
noise_pred_uncond = unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||||
|
noise_pred_text = unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
#del noise_pred_text
|
||||||
|
#del noise_pred_uncond
|
||||||
|
#import gc
|
||||||
|
#gc.collect()
|
||||||
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||||
|
|
||||||
|
#del noise_pred
|
||||||
|
#import gc
|
||||||
|
#gc.collect()
|
||||||
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
#if callback is not None and i % callback_steps == 0:
|
||||||
|
# callback(i, t, latents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#################
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
|
context.services.latents.save(name, latents)
|
||||||
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
|
||||||
|
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
|
type: Literal["l2l_sdxl"] = "l2l_sdxl"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
# fmt: off
|
||||||
|
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||||
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
|
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||||
|
|
||||||
|
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
|
||||||
|
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||||
|
|
||||||
|
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
|
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
@validator("cfg_scale")
def ge_one(cls, v):
    """Validate that every cfg_scale value is >= 1.

    Args:
        v: The ``cfg_scale`` value being validated, either a single number
            or a list of numbers.

    Returns:
        ``v`` unchanged when every value passes the check.

    Raises:
        ValueError: If the value, or any element of the list, is below 1.
    """
    # Treat the scalar form as a one-element list so both shapes share one check.
    values = v if isinstance(v, list) else [v]
    if any(item < 1 for item in values):
        # A value of exactly 1 is accepted, so the message says "or equal to".
        raise ValueError('cfg_scale must be greater than or equal to 1')
    return v
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model",
|
||||||
|
# "cfg_scale": "float",
|
||||||
|
"cfg_scale": "number"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# based on
|
||||||
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
|
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||||
|
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||||
|
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||||
|
|
||||||
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
|
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||||
|
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||||
|
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||||
|
|
||||||
|
scheduler = get_scheduler(
|
||||||
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply denoising_start
|
||||||
|
num_inference_steps = self.steps
|
||||||
|
scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||||
|
timesteps = scheduler.timesteps[t_start * scheduler.order:]
|
||||||
|
num_inference_steps = num_inference_steps - t_start
|
||||||
|
|
||||||
|
# apply noise(if provided)
|
||||||
|
if self.noise is not None:
|
||||||
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||||
|
del noise
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict()
|
||||||
|
)
|
||||||
|
do_classifier_free_guidance = True
|
||||||
|
cross_attention_kwargs = None
|
||||||
|
with unet_info as unet:
|
||||||
|
|
||||||
|
# apply scheduler extra args
|
||||||
|
extra_step_kwargs = dict()
|
||||||
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
|
extra_step_kwargs.update(
|
||||||
|
eta=0.0,
|
||||||
|
)
|
||||||
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
|
extra_step_kwargs.update(
|
||||||
|
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||||
|
|
||||||
|
# apply denoising_end
|
||||||
|
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||||
|
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||||
|
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||||
|
|
||||||
|
if not context.services.configuration.sequential_guidance:
|
||||||
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||||
|
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||||
|
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||||
|
|
||||||
|
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
with tqdm(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
|
||||||
|
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||||
|
noise_pred = unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
#del noise_pred_uncond
|
||||||
|
#del noise_pred_text
|
||||||
|
|
||||||
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
#if callback is not None and i % callback_steps == 0:
|
||||||
|
# callback(i, t, latents)
|
||||||
|
else:
|
||||||
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
with tqdm(total=num_inference_steps) as progress_bar:
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
Loading…
Reference in New Issue
Block a user