mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -7,13 +7,13 @@ from pydantic import Field, validator
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
|
||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||
from .compel import ConditioningField
|
||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||
|
||||
|
||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
@ -26,16 +26,19 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@ -125,8 +128,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
@ -196,7 +201,8 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
@ -213,9 +219,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -224,10 +230,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
@ -237,10 +243,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"title": "SDXL Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -265,9 +271,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
@ -293,14 +297,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context
|
||||
)
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
@ -350,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@ -364,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
@ -378,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
@ -411,42 +411,41 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
|
||||
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
@ -466,9 +465,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -477,10 +476,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
@ -490,10 +489,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"title": "SDXL Latents to Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -518,9 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
@ -545,7 +542,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order:]
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# apply noise(if provided)
|
||||
@ -555,12 +552,12 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
del noise
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context,
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
# apply scheduler extra args
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
@ -611,10 +608,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@ -625,7 +622,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
@ -639,13 +636,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
@ -672,38 +669,36 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
Reference in New Issue
Block a user