mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
a514c9e28b
Update workflows handling for Workflow Library. **Updated Workflow Storage** "Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in DB. This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to in image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost. **Updated Workflow Handling in Nodes** Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically. A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`. **Database Migrations** Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accommodate the DB-related changes in this PR. See the code for details. The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A followup PR will adapt them to use the migrator. **Other/Support Changes** - Add a `has_workflow` column to `images` table to indicate that the image has an embedded workflow. - Add handling for retrieving the workflow from an image in python. The image file must be fetched, the workflow extracted, and then sent to client, avoiding needing the browser to parse the image file. 
With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow. - Add route to get the workflow from an image - Add CRUD service/routes for the library workflows - `workflow_images` table and services removed (no longer needed now that embedded workflows are not in the DB)
1164 lines
45 KiB
Python
1164 lines
45 KiB
Python
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from contextlib import ExitStack
|
|
from functools import singledispatchmethod
|
|
from typing import List, Literal, Optional, Union
|
|
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
import torchvision.transforms as T
|
|
from diffusers import AutoencoderKL, AutoencoderTiny
|
|
from diffusers.image_processor import VaeImageProcessor
|
|
from diffusers.models.adapter import T2IAdapter
|
|
from diffusers.models.attention_processor import (
|
|
AttnProcessor2_0,
|
|
LoRAAttnProcessor2_0,
|
|
LoRAXFormersAttnProcessor,
|
|
XFormersAttnProcessor,
|
|
)
|
|
from diffusers.schedulers import DPMSolverSDEScheduler
|
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
|
from pydantic import field_validator
|
|
from torchvision.transforms.functional import resize as tv_resize
|
|
|
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
|
from invokeai.app.invocations.primitives import (
|
|
DenoiseMaskField,
|
|
DenoiseMaskOutput,
|
|
ImageField,
|
|
ImageOutput,
|
|
LatentsField,
|
|
LatentsOutput,
|
|
build_latents_output,
|
|
)
|
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
from invokeai.app.shared.fields import FieldDescriptions
|
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
|
|
|
from ...backend.model_management.lora import ModelPatcher
|
|
from ...backend.model_management.models import BaseModelType
|
|
from ...backend.model_management.seamless import set_seamless
|
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
|
ControlNetData,
|
|
IPAdapterData,
|
|
StableDiffusionGeneratorPipeline,
|
|
T2IAdapterData,
|
|
image_resized_to_grid_as_tensor,
|
|
)
|
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
Input,
|
|
InputField,
|
|
InvocationContext,
|
|
OutputField,
|
|
UIType,
|
|
WithMetadata,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
from .compel import ConditioningField
|
|
from .controlnet_image_processors import ControlField
|
|
from .model import ModelInfo, UNetField, VaeField
|
|
|
|
# The `mps` module is imported only when the selected torch device is "mps"; code further down
# guards every use of `mps` with the same device check.
if choose_torch_device() == torch.device("mps"):
    from torch import mps

# Precision chosen for the selected torch device. The `fp32` node inputs below default to
# `DEFAULT_PRECISION == "float32"`.
DEFAULT_PRECISION = choose_precision(choose_torch_device())

# Literal type whose allowed values are the keys of SCHEDULER_MAP.
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP)]
|
|
|
|
|
|
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
    # Single output: the selected scheduler name, constrained to the SCHEDULER_MAP keys.
    scheduler: SAMPLER_NAME_VALUES = OutputField(
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )
|
|
|
|
|
|
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
class SchedulerInvocation(BaseInvocation):
    """Selects a scheduler."""

    # Name of the scheduler to hand to downstream nodes; defaults to "euler".
    scheduler: SAMPLER_NAME_VALUES = InputField(
        default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
    )

    def invoke(self, context: InvocationContext) -> SchedulerOutput:
        # Pure pass-through: the chosen name is wrapped in the output type; `context` is not used.
        selected = self.scheduler
        return SchedulerOutput(scheduler=selected)
|
|
|
|
|
|
@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.0",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == "float32",
        description=FieldDescriptions.fp32,
        ui_order=4,
    )

    def prep_mask_tensor(self, mask_image):
        """Turn a PIL mask image into a tensor with a leading batch dimension.

        The image is converted to mode "L" when it is not already, passed through
        `image_resized_to_grid_as_tensor(..., normalize=False)`, and a leading dimension is added
        when the resulting tensor is 3-dimensional.
        """
        grayscale = mask_image if mask_image.mode == "L" else mask_image.convert("L")
        tensor = image_resized_to_grid_as_tensor(grayscale, normalize=False)
        return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        """Save the mask (and, when an image is given, the VAE-encoded masked image) to the latents store.

        Returns a `DenoiseMaskOutput` whose field holds the name of the saved mask and the name of
        the saved masked latents; the latter is `None` when no image was provided.
        """
        # Load the optional source image first, as an RGB tensor with a batch dimension.
        source = None
        if self.image is not None:
            pil_source = context.services.images.get_pil_image(self.image.image_name)
            source = image_resized_to_grid_as_tensor(pil_source.convert("RGB"))
            if source.dim() == 3:
                source = source.unsqueeze(0)

        pil_mask = context.services.images.get_pil_image(self.mask.image_name)
        mask = self.prep_mask_tensor(pil_mask)

        masked_latents_name = None
        if source is not None:
            vae_info = context.services.model_manager.get_model(
                **self.vae.vae.model_dump(),
                context=context,
            )

            # Resize the mask to the image's spatial size, binarise it at 0.5 and zero out the
            # image wherever the resized mask is below that threshold.
            resized_mask = tv_resize(mask, source.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            keep = torch.where(resized_mask < 0.5, 0.0, 1.0)
            masked_source = source * keep
            # TODO:
            encoded = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_source.clone())

            masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
            context.services.latents.save(masked_latents_name, encoded)

        # The (un-resized) mask tensor is always saved.
        mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
        context.services.latents.save(mask_name, mask)

        mask_field = DenoiseMaskField(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
        )
        return DenoiseMaskOutput(denoise_mask=mask_field)
|
|
|
|
|
|
def get_scheduler(
    context: InvocationContext,
    scheduler_info: ModelInfo,
    scheduler_name: str,
    seed: int,
) -> Scheduler:
    """Build a scheduler instance of the requested kind from a model's scheduler config.

    Args:
        context: Invocation context; its model manager is used to load the model's scheduler.
        scheduler_info: Identifies the scheduler sub-model whose config is used as the base.
        scheduler_name: Key into SCHEDULER_MAP; unknown names fall back to the "ddim" entry.
        seed: Used as `noise_sampler_seed` when the scheduler class is `DPMSolverSDEScheduler`.

    Returns:
        A new scheduler created via `from_config`, guaranteed to have a `uses_inpainting_model`
        attribute.
    """
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])

    model_info = context.services.model_manager.get_model(
        **scheduler_info.model_dump(),
        context=context,
    )
    # Only the config is read while the model is held.
    with model_info as loaded_scheduler:
        base_config = loaded_scheduler.config

    # If the config already carries a "_backup" entry, start from that backup rather than from
    # the config itself.
    if "_backup" in base_config:
        base_config = base_config["_backup"]

    # Base config, overridden by the per-scheduler extras, with the base kept under "_backup".
    merged_config = dict(base_config)
    merged_config.update(scheduler_extra_config)
    merged_config["_backup"] = base_config

    # make dpmpp_sde reproducible (seed can be passed only in initializer)
    if scheduler_class is DPMSolverSDEScheduler:
        merged_config["noise_sampler_seed"] = seed

    scheduler = scheduler_class.from_config(merged_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, "uses_inpainting_model"):
        scheduler.uses_inpainting_model = lambda: False
    return scheduler
|
|
|
|
|
|
@invocation(
    "denoise_latents",
    title="Denoise Latents",
    tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
    category="latents",
    version="1.4.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
    """Denoises noisy latents to decodable images"""

    positive_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
    )
    negative_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
    )
    noise: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.noise,
        input=Input.Connection,
        ui_order=3,
    )
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    cfg_scale: Union[float, List[float]] = InputField(
        default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
    )
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    scheduler: SAMPLER_NAME_VALUES = InputField(
        default="euler",
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )
    unet: UNetField = InputField(
        description=FieldDescriptions.unet,
        input=Input.Connection,
        title="UNet",
        ui_order=2,
    )
    control: Optional[Union[ControlField, list[ControlField]]] = InputField(
        default=None,
        input=Input.Connection,
        ui_order=5,
    )
    ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
        description=FieldDescriptions.ip_adapter,
        title="IP-Adapter",
        default=None,
        input=Input.Connection,
        ui_order=6,
    )
    t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
        description=FieldDescriptions.t2i_adapter,
        title="T2I-Adapter",
        default=None,
        input=Input.Connection,
        ui_order=7,
    )
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
        ui_order=4,
    )
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.mask,
        input=Input.Connection,
        ui_order=8,
    )

    @field_validator("cfg_scale")
    @classmethod
    def ge_one(cls, v):
        """Validate that all cfg_scale values are >= 1.

        `v` is either a single number or a list of numbers; it is returned unchanged when valid.

        Raises:
            ValueError: if any value is below 1.
        """
        values = v if isinstance(v, list) else [v]
        if any(value < 1 for value in values):
            # The check allows exactly 1, so the message says "or equal to".
            raise ValueError("cfg_scale must be greater than or equal to 1")
        return v

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self,
        context: InvocationContext,
        source_node_id: str,
        intermediate_state: PipelineIntermediateState,
        base_model: BaseModelType,
    ) -> None:
        """Forward one intermediate pipeline state to `stable_diffusion_step_callback`.

        This node is passed along as a plain dict (`self.model_dump()`).
        """
        stable_diffusion_step_callback(
            context=context,
            intermediate_state=intermediate_state,
            node=self.model_dump(),
            source_node_id=source_node_id,
            base_model=base_model,
        )

    def get_conditioning_data(
        self,
        context: InvocationContext,
        scheduler,
        unet,
        seed,
    ) -> ConditioningData:
        """Assemble the `ConditioningData` for a denoising run.

        The first conditioning of the positive and of the negative conditioning input is loaded
        from the latents store and moved to the UNet's device and dtype. Scheduler-specific
        arguments (`eta`, `generator`) are attached via `add_scheduler_args_if_applicable`.
        """
        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
        extra_conditioning_info = c.extra_conditioning

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

        conditioning_data = ConditioningData(
            unconditioned_embeddings=uc,
            text_embeddings=c,
            guidance_scale=self.cfg_scale,
            extra=extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=0.0,  # threshold,
                warmup=0.2,  # warmup,
                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
            ),
        )

        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
            scheduler,
            # for ddim scheduler
            eta=0.0,  # ddim_eta
            # for ancestral and sde schedulers
            # flip all bits to have noise different from initial
            generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
        )
        return conditioning_data

    def create_pipeline(
        self,
        unet,
        scheduler,
    ) -> StableDiffusionGeneratorPipeline:
        """Create a `StableDiffusionGeneratorPipeline` around the given UNet and scheduler.

        No real VAE, text encoder, tokenizer, safety checker or feature extractor is supplied; a
        stub VAE exposing only `config.block_out_channels` stands in for the VAE.
        """
        # TODO:
        # configure_model_padding(
        #     unet,
        #     self.seamless,
        #     self.seamless_axes,
        # )

        class FakeVae:
            class FakeVaeConfig:
                def __init__(self):
                    self.block_out_channels = [0]

            def __init__(self):
                self.config = FakeVae.FakeVaeConfig()

        return StableDiffusionGeneratorPipeline(
            vae=FakeVae(),  # TODO: oh...
            text_encoder=None,
            tokenizer=None,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )

    def prep_control_data(
        self,
        context: InvocationContext,
        control_input: Optional[Union[ControlField, List[ControlField]]],
        latents_shape: List[int],
        exit_stack: ExitStack,
        do_classifier_free_guidance: bool = True,
    ) -> Optional[List[ControlNetData]]:
        """Load ControlNet models and prepare their control images.

        Args:
            context: Invocation context used to fetch models and images.
            control_input: A single `ControlField`, a list of them, or `None`.
            latents_shape: Shape of the latents; indices 2 and 3 are used as height and width.
            exit_stack: Stack on which each loaded ControlNet model context is entered, so the
                models remain held until the caller closes the stack.
            do_classifier_free_guidance: Forwarded to `prepare_control_image`.

        Returns:
            One `ControlNetData` per control entry, or `None` when `control_input` is `None`, an
            empty list, or anything other than a `ControlField` / list starting with one.
        """
        # assuming fixed dimensional scaling of 8:1 for image:latents
        control_height_resize = latents_shape[2] * 8
        control_width_resize = latents_shape[3] * 8

        # Normalize the input to a list[ControlField]; everything else means "no control".
        if isinstance(control_input, ControlField):
            control_list = [control_input]
        elif isinstance(control_input, list) and control_input and isinstance(control_input[0], ControlField):
            control_list = control_input
        else:
            return None

        # FIXME: add checks to skip entry if model or image is None
        #        and if weight is None, populate with default 1.0?
        controlnet_data = []
        for control_info in control_list:
            control_model = exit_stack.enter_context(
                context.services.model_manager.get_model(
                    model_name=control_info.control_model.model_name,
                    model_type=ModelType.ControlNet,
                    base_model=control_info.control_model.base_model,
                    context=context,
                )
            )

            input_image = context.services.images.get_pil_image(control_info.image.image_name)
            # FIXME: still need to test with different widths, heights, devices, dtypes
            #        and add in batch_size, num_images_per_prompt?
            #        and do real check for classifier_free_guidance?
            # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
            control_image = prepare_control_image(
                image=input_image,
                do_classifier_free_guidance=do_classifier_free_guidance,
                width=control_width_resize,
                height=control_height_resize,
                # batch_size=batch_size * num_images_per_prompt,
                # num_images_per_prompt=num_images_per_prompt,
                device=control_model.device,
                dtype=control_model.dtype,
                control_mode=control_info.control_mode,
                resize_mode=control_info.resize_mode,
            )
            controlnet_data.append(
                ControlNetData(
                    model=control_model,  # model object
                    image_tensor=control_image,
                    weight=control_info.control_weight,
                    begin_step_percent=control_info.begin_step_percent,
                    end_step_percent=control_info.end_step_percent,
                    control_mode=control_info.control_mode,
                    # any resizing needed should currently be happening in prepare_control_image(),
                    #    but adding resize_mode to ControlNetData in case needed in the future
                    resize_mode=control_info.resize_mode,
                )
            )
        # MultiControlNetModel has been refactored out, just need list[ControlNetData]

        return controlnet_data

    def prep_ip_adapter_data(
        self,
        context: InvocationContext,
        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
        conditioning_data: ConditioningData,
        exit_stack: ExitStack,
    ) -> Optional[list[IPAdapterData]]:
        """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
        to the `conditioning_data` (in-place).

        Each IP-Adapter model context is entered on `exit_stack`; the image encoder is only held
        while the embeddings are computed. Returns `None` when `ip_adapter` is `None` or empty.
        """
        if ip_adapter is None:
            return None

        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
        if not isinstance(ip_adapter, list):
            ip_adapter = [ip_adapter]

        if len(ip_adapter) == 0:
            return None

        ip_adapter_data_list = []
        conditioning_data.ip_adapter_conditioning = []
        for single_ip_adapter in ip_adapter:
            ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
                context.services.model_manager.get_model(
                    model_name=single_ip_adapter.ip_adapter_model.model_name,
                    model_type=ModelType.IPAdapter,
                    base_model=single_ip_adapter.ip_adapter_model.base_model,
                    context=context,
                )
            )

            image_encoder_model_info = context.services.model_manager.get_model(
                model_name=single_ip_adapter.image_encoder_model.model_name,
                model_type=ModelType.CLIPVision,
                base_model=single_ip_adapter.image_encoder_model.base_model,
                context=context,
            )

            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
            single_ipa_images = single_ip_adapter.image
            if not isinstance(single_ipa_images, list):
                single_ipa_images = [single_ipa_images]

            single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]

            # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
            # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
            with image_encoder_model_info as image_encoder_model:
                # Get image embeddings from CLIP and ImageProjModel.
                image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                    single_ipa_images, image_encoder_model
                )

                conditioning_data.ip_adapter_conditioning.append(
                    IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
                )

            ip_adapter_data_list.append(
                IPAdapterData(
                    ip_adapter_model=ip_adapter_model,
                    weight=single_ip_adapter.weight,
                    begin_step_percent=single_ip_adapter.begin_step_percent,
                    end_step_percent=single_ip_adapter.end_step_percent,
                )
            )

        return ip_adapter_data_list

    def run_t2i_adapters(
        self,
        context: InvocationContext,
        t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
        latents_shape: list[int],
        do_classifier_free_guidance: bool,
    ) -> Optional[list[T2IAdapterData]]:
        """Run every connected T2I-Adapter on its input image and collect the adapter states.

        Returns `None` when `t2i_adapter` is `None` or empty. When `do_classifier_free_guidance`
        is true, each tensor of an adapter state is duplicated along dim 0.

        Raises:
            ValueError: if an adapter's base model is neither StableDiffusion1 nor StableDiffusionXL.
        """
        if t2i_adapter is None:
            return None

        # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
        if isinstance(t2i_adapter, T2IAdapterField):
            t2i_adapter = [t2i_adapter]

        if len(t2i_adapter) == 0:
            return None

        t2i_adapter_data = []
        for t2i_adapter_field in t2i_adapter:
            t2i_adapter_model_info = context.services.model_manager.get_model(
                model_name=t2i_adapter_field.t2i_adapter_model.model_name,
                model_type=ModelType.T2IAdapter,
                base_model=t2i_adapter_field.t2i_adapter_model.base_model,
                context=context,
            )
            image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)

            # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
            if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
                max_unet_downscale = 8
            elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
                max_unet_downscale = 4
            else:
                raise ValueError(
                    f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
                )

            t2i_adapter_model: T2IAdapter
            with t2i_adapter_model_info as t2i_adapter_model:
                total_downscale_factor = t2i_adapter_model.total_downscale_factor

                # Resize the T2I-Adapter input image.
                # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
                # result will match the latent image's dimensions after max_unet_downscale is applied.
                t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
                t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor

                # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
                # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
                # T2I-Adapter model.
                #
                # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
                # of the same requirements (e.g. preserving binary masks during resize).
                t2i_image = prepare_control_image(
                    image=image,
                    do_classifier_free_guidance=False,
                    width=t2i_input_width,
                    height=t2i_input_height,
                    num_channels=t2i_adapter_model.config.in_channels,
                    device=t2i_adapter_model.device,
                    dtype=t2i_adapter_model.dtype,
                    resize_mode=t2i_adapter_field.resize_mode,
                )

                adapter_state = t2i_adapter_model(t2i_image)

            if do_classifier_free_guidance:
                for idx, value in enumerate(adapter_state):
                    adapter_state[idx] = torch.cat([value] * 2, dim=0)

            t2i_adapter_data.append(
                T2IAdapterData(
                    adapter_state=adapter_state,
                    weight=t2i_adapter_field.weight,
                    begin_step_percent=t2i_adapter_field.begin_step_percent,
                    end_step_percent=t2i_adapter_field.end_step_percent,
                )
            )

        return t2i_adapter_data

    # original idea by https://github.com/AmericanPresidentJimmyCarter
    # TODO: research more for second order schedulers timesteps
    def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
        """Set the scheduler's timesteps and select the slice for the requested denoising window.

        Args:
            scheduler: Scheduler exposing `config`, `set_timesteps`, `timesteps` and `order`.
            device: Device the returned timesteps should live on.
            steps: Number of steps passed to `scheduler.set_timesteps`.
            denoising_start: Start of the window, as a fraction in [0, 1].
            denoising_end: End of the window, as a fraction in [0, 1].

        Returns:
            `(num_inference_steps, timesteps, init_timestep)`: the step count for the selected
            slice, the sliced timesteps, and a slice of at most one element holding its first timestep.
        """
        if scheduler.config.get("cpu_only", False):
            # Schedulers flagged "cpu_only" compute their timesteps on CPU; move them afterwards.
            scheduler.set_timesteps(steps, device="cpu")
            timesteps = scheduler.timesteps.to(device=device)
        else:
            scheduler.set_timesteps(steps, device=device)
            timesteps = scheduler.timesteps

        # skip greater order timesteps
        _timesteps = timesteps[:: scheduler.order]

        # get start timestep index: number of (order-strided) timesteps at or above the start value
        t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
        t_start_idx = sum(1 for ts in _timesteps if ts >= t_start_val)

        # get end timestep index: counted relative to the start index
        t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
        t_end_idx = sum(1 for ts in _timesteps[t_start_idx:] if ts >= t_end_val)

        # apply order to indexes
        t_start_idx *= scheduler.order
        t_end_idx *= scheduler.order

        init_timestep = timesteps[t_start_idx : t_start_idx + 1]
        timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
        num_inference_steps = len(timesteps) // scheduler.order

        return num_inference_steps, timesteps, init_timestep

    def prep_inpaint_mask(self, context, latents):
        """Load the denoise mask (resized to the latents' spatial size) and optional masked latents.

        Returns `(None, None)` when no denoise mask is connected; otherwise `(1 - mask,
        masked_latents)`, where `masked_latents` is `None` if the mask field names none.
        """
        if self.denoise_mask is None:
            return None, None

        mask = context.services.latents.get(self.denoise_mask.mask_name)
        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
        if self.denoise_mask.masked_latents_name is not None:
            masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
        else:
            masked_latents = None

        return 1 - mask, masked_latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the denoising loop and save the resulting latents.

        The seed is taken from `noise` if connected, else from `latents`, else 0. When only
        `noise` is connected the starting latents are zeros of the same shape.

        Raises:
            ValueError: if neither `latents` nor `noise` is provided, or if their shapes
                (ignoring the first dimension) differ.
        """
        with SilenceWarnings():  # this quenches NSFW nag from diffusers
            seed = None
            noise = None
            if self.noise is not None:
                noise = context.services.latents.get(self.noise.latents_name)
                seed = self.noise.seed

            if self.latents is not None:
                latents = context.services.latents.get(self.latents.latents_name)
                if seed is None:
                    seed = self.latents.seed

                if noise is not None and noise.shape[1:] != latents.shape[1:]:
                    raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

            elif noise is not None:
                latents = torch.zeros_like(noise)
            else:
                raise ValueError("'latents' or 'noise' must be provided!")

            if seed is None:
                seed = 0

            mask, masked_latents = self.prep_inpaint_mask(context, latents)

            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
            # below. Investigate whether this is appropriate.
            t2i_adapter_data = self.run_t2i_adapters(
                context,
                self.t2i_adapter,
                latents.shape,
                do_classifier_free_guidance=True,
            )

            # Get the source node id (we are invoking the prepared node)
            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

            def step_callback(state: PipelineIntermediateState):
                self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)

            def _lora_loader():
                # Lazily yields (model, weight) pairs, one LoRA at a time.
                for lora in self.unet.loras:
                    lora_info = context.services.model_manager.get_model(
                        **lora.model_dump(exclude={"weight"}),
                        context=context,
                    )
                    yield (lora_info.context.model, lora.weight)
                    del lora_info

            unet_info = context.services.model_manager.get_model(
                **self.unet.unet.model_dump(),
                context=context,
            )
            with (
                ExitStack() as exit_stack,
                ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
                set_seamless(unet_info.context.model, self.unet.seamless_axes),
                unet_info as unet,
                # Apply the LoRA after unet has been moved to its target device for faster patching.
                ModelPatcher.apply_lora_unet(unet, _lora_loader()),
            ):
                # Bring every tensor input onto the UNet's device and dtype.
                latents = latents.to(device=unet.device, dtype=unet.dtype)
                if noise is not None:
                    noise = noise.to(device=unet.device, dtype=unet.dtype)
                if mask is not None:
                    mask = mask.to(device=unet.device, dtype=unet.dtype)
                if masked_latents is not None:
                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

                scheduler = get_scheduler(
                    context=context,
                    scheduler_info=self.unet.scheduler,
                    scheduler_name=self.scheduler,
                    seed=seed,
                )

                pipeline = self.create_pipeline(unet, scheduler)
                conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)

                controlnet_data = self.prep_control_data(
                    context=context,
                    control_input=self.control,
                    latents_shape=latents.shape,
                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                    do_classifier_free_guidance=True,
                    exit_stack=exit_stack,
                )

                ip_adapter_data = self.prep_ip_adapter_data(
                    context=context,
                    ip_adapter=self.ip_adapter,
                    conditioning_data=conditioning_data,
                    exit_stack=exit_stack,
                )

                num_inference_steps, timesteps, init_timestep = self.init_scheduler(
                    scheduler,
                    device=unet.device,
                    steps=self.steps,
                    denoising_start=self.denoising_start,
                    denoising_end=self.denoising_end,
                )

                # The second element of the result (attention map saver) is not used here.
                result_latents, _ = pipeline.latents_from_embeddings(
                    latents=latents,
                    timesteps=timesteps,
                    init_timestep=init_timestep,
                    noise=noise,
                    seed=seed,
                    mask=mask,
                    masked_latents=masked_latents,
                    num_inference_steps=num_inference_steps,
                    conditioning_data=conditioning_data,
                    control_data=controlnet_data,
                    ip_adapter_data=ip_adapter_data,
                    t2i_adapter_data=t2i_adapter_data,
                    callback=step_callback,
                )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()
            if choose_torch_device() == torch.device("mps"):
                mps.empty_cache()

            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
|
|
|
|
|
@invocation(
    "l2i",
    title="Latents to Image",
    tags=["latents", "image", "vae", "l2i"],
    category="latents",
    version="1.2.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
    """Generates an image from latents."""

    # Stored latents to decode and the VAE used to decode them (both connection-only inputs).
    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    # Tiling is enabled when either this flag or the app-level `tiled_decode` setting is truthy.
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
    # When true the VAE is moved to float32, otherwise to float16.
    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Decode the referenced latents with the VAE and register the resulting image.

        Args:
            context: Invocation context giving access to the latents store, the model
                manager, the app configuration and the images service.

        Returns:
            An ImageOutput carrying the name, width and height of the created image.
        """
        latents = context.services.latents.get(self.latents.latents_name)
        vae_info = context.services.model_manager.get_model(**self.vae.vae.model_dump(), context=context)

        with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
            latents = latents.to(vae.device)

            if not self.fp32:
                vae.to(dtype=torch.float16)
                latents = latents.half()
            else:
                vae.to(dtype=torch.float32)

                memory_efficient_processors = (
                    AttnProcessor2_0,
                    XFormersAttnProcessor,
                    LoRAXFormersAttnProcessor,
                    LoRAAttnProcessor2_0,
                )
                attn_processor = vae.decoder.mid_block.attentions[0].processor
                if isinstance(attn_processor, memory_efficient_processors):
                    # With xformers or torch 2.0 attention the attention block does not need
                    # to be in float32, which can save lots of memory.
                    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
                        module.to(latents.dtype)
                else:
                    latents = latents.float()

            use_tiling = self.tiled or context.services.configuration.tiled_decode
            if use_tiling:
                vae.enable_tiling()
            else:
                vae.disable_tiling()

            # Release cached memory up front: the VAE decode can request a lot.
            torch.cuda.empty_cache()
            if choose_torch_device() == torch.device("mps"):
                mps.empty_cache()

            with torch.inference_mode():
                # Same post-processing as the diffusers pipeline.
                latents = latents / vae.config.scaling_factor
                decoded = vae.decode(latents, return_dict=False)[0]
                decoded = (decoded / 2 + 0.5).clamp(0, 1)  # denormalize
                # Always cast to float32: no significant overhead, and compatible with bfloat16.
                np_image = decoded.cpu().permute(0, 2, 3, 1).float().numpy()

                image = VaeImageProcessor.numpy_to_pil(np_image)[0]

        torch.cuda.empty_cache()
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata,
            workflow=context.workflow,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
|
|
|
|
|
|
# Interpolation modes accepted by the latent resize/scale nodes; the selected value is
# passed unchanged as the `mode` argument of torch.nn.functional.interpolate.
LATENTS_INTERPOLATION_MODE = Literal[
    "nearest",
    "linear",
    "bilinear",
    "bicubic",
    "trilinear",
    "area",
    "nearest-exact",
]
|
|
|
|
|
|
@invocation(
    "lresize",
    title="Resize Latents",
    tags=["latents", "resize"],
    category="latents",
    version="1.0.0",
)
class ResizeLatentsInvocation(BaseInvocation):
    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    width: int = InputField(
        ge=64,
        multiple_of=8,
        description=FieldDescriptions.width,
    )
    # Uses the height description; this field previously reused FieldDescriptions.width,
    # which labelled the height input as a width.
    height: int = InputField(
        ge=64,
        multiple_of=8,
        description=FieldDescriptions.height,
    )
    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Interpolate the referenced latents to (height // 8, width // 8) and store the result.

        Args:
            context: Invocation context giving access to the latents store.

        Returns:
            A LatentsOutput for the newly saved latents, carrying over the input's seed.
        """
        latents = context.services.latents.get(self.latents.latents_name)

        device = choose_torch_device()

        resized_latents = torch.nn.functional.interpolate(
            latents.to(device),
            size=(self.height // 8, self.width // 8),
            mode=self.mode,
            # The antialias flag is only forwarded for bilinear/bicubic; other modes get False.
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # Move the result off the compute device and release cached memory.
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        resized_latents = resized_latents.to("cpu")
        torch.cuda.empty_cache()
        if device == torch.device("mps"):
            mps.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
|
|
|
|
@invocation(
    "lscale",
    title="Scale Latents",
    tags=["latents", "resize"],
    category="latents",
    version="1.0.0",
)
class ScaleLatentsInvocation(BaseInvocation):
    """Scales latents by a given factor."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    # Must be strictly positive; passed straight through as `scale_factor` to interpolate.
    scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Interpolate the referenced latents by `scale_factor` and store the result.

        Args:
            context: Invocation context giving access to the latents store.

        Returns:
            A LatentsOutput for the newly saved latents, carrying over the input's seed.
        """
        source_latents = context.services.latents.get(self.latents.latents_name)
        device = choose_torch_device()

        # The antialias flag is only forwarded for bilinear/bicubic; other modes get False.
        use_antialias = self.antialias if self.mode in ["bilinear", "bicubic"] else False

        scaled_latents = torch.nn.functional.interpolate(
            source_latents.to(device),
            scale_factor=self.scale_factor,
            mode=self.mode,
            antialias=use_antialias,
        )

        # Move the result off the compute device and release cached memory.
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        scaled_latents = scaled_latents.to("cpu")
        torch.cuda.empty_cache()
        if device == torch.device("mps"):
            mps.empty_cache()

        latents_name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(latents_name, scaled_latents)
        return build_latents_output(latents_name=latents_name, latents=scaled_latents, seed=self.latents.seed)
|
|
|
|
|
|
@invocation(
    "i2l",
    title="Image to Latents",
    tags=["latents", "image", "vae", "i2l"],
    category="latents",
    version="1.0.0",
)
class ImageToLatentsInvocation(BaseInvocation):
    """Encodes an image into latents."""

    image: ImageField = InputField(description="The image to encode")
    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
    # When true the VAE is moved to float32 for encoding, otherwise to float16.
    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

    @staticmethod
    def vae_encode(vae_info, upcast, tiled, image_tensor):
        """Encode `image_tensor` with the VAE held by `vae_info` and return scaled latents.

        Args:
            vae_info: Context manager yielding the VAE model when entered.
            upcast: If truthy, move the VAE to float32; otherwise move it to float16.
            tiled: If truthy, enable VAE tiling; otherwise disable it.
            image_tensor: Image tensor to encode; it is moved to the VAE's device and dtype.

        Returns:
            The encoded latents multiplied by the VAE's scaling factor, cast back to the
            dtype the VAE had when it was entered.
        """
        with vae_info as vae:
            orig_dtype = vae.dtype

            if not upcast:
                vae.to(dtype=torch.float16)
            else:
                vae.to(dtype=torch.float32)

                memory_efficient_processors = (
                    AttnProcessor2_0,
                    XFormersAttnProcessor,
                    LoRAXFormersAttnProcessor,
                    LoRAAttnProcessor2_0,
                )
                attn_processor = vae.decoder.mid_block.attentions[0].processor
                if isinstance(attn_processor, memory_efficient_processors):
                    # With xformers or torch 2.0 attention the attention block does not need
                    # to be in float32, which can save lots of memory.
                    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
                        module.to(orig_dtype)

            if tiled:
                vae.enable_tiling()
            else:
                vae.disable_tiling()

            # Encode the image as-is (no noise is added here).
            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
            with torch.inference_mode():
                latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

            latents = (vae.config.scaling_factor * latents).to(dtype=orig_dtype)

        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the input image, encode it with the VAE and store the resulting latents.

        Args:
            context: Invocation context giving access to the images service, the model
                manager and the latents store.

        Returns:
            A LatentsOutput for the newly saved latents, with no seed attached.
        """
        pil_image = context.services.images.get_pil_image(self.image.image_name)

        vae_info = context.services.model_manager.get_model(**self.vae.vae.model_dump(), context=context)

        image_tensor = image_resized_to_grid_as_tensor(pil_image.convert("RGB"))
        if image_tensor.dim() == 3:
            # Add a leading batch dimension.
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        encoded_latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)

        latents_name = f"{context.graph_execution_state_id}__{self.id}"
        encoded_latents = encoded_latents.to("cpu")
        context.services.latents.save(latents_name, encoded_latents)
        return build_latents_output(latents_name=latents_name, latents=encoded_latents, seed=None)

    # Dispatches on the type of the first argument (the VAE); this is the AutoencoderKL variant.
    @singledispatchmethod
    @staticmethod
    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        # FIXME: sampling uses torch.randn. make reproducible!
        return vae.encode(image_tensor).latent_dist.sample().to(dtype=vae.dtype)

    # AutoencoderTiny variant: the encoder output exposes the latents directly.
    @_encode_to_tensor.register
    @staticmethod
    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        encoder_output = vae.encode(image_tensor)
        return encoder_output.latents
|
|
|
|
|
|
@invocation(
    "lblend",
    title="Blend Latents",
    tags=["latents", "blend"],
    category="latents",
    version="1.0.0",
)
class BlendLatentsInvocation(BaseInvocation):
    """Blend two latents using a given alpha. Latents must have same size."""

    latents_a: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    latents_b: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Spherically interpolate between the two referenced latents and store the result.

        Args:
            context: Invocation context giving access to the latents store.

        Returns:
            A LatentsOutput for the newly saved blended latents.

        Raises:
            ValueError: If the two latents do not have the same shape.
        """
        latents_a = context.services.latents.get(self.latents_a.latents_name)
        latents_b = context.services.latents.get(self.latents_b.latents_name)

        if latents_a.shape != latents_b.shape:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Latents to blend must be the same size.")

        device = choose_torch_device()

        def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
            """
            Spherical linear interpolation
            Args:
                t (float/np.ndarray): Float value between 0.0 and 1.0
                v0 (np.ndarray): Starting vector
                v1 (np.ndarray): Final vector
                DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                       collinear. Not recommended to alter this.
            Returns:
                v2 (np.ndarray): Interpolation vector between v0 and v1
            """
            inputs_are_torch = False
            if not isinstance(v0, np.ndarray):
                inputs_are_torch = True
                v0 = v0.detach().cpu().numpy()
            if not isinstance(v1, np.ndarray):
                inputs_are_torch = True
                v1 = v1.detach().cpu().numpy()

            norm_product = np.linalg.norm(v0) * np.linalg.norm(v1)
            # A zero-norm input has no direction: dividing by the norm product would turn
            # the whole result into NaN, so treat that case like collinear vectors.
            dot = np.sum(v0 * v1 / norm_product) if norm_product > 0 else None

            if dot is None or np.abs(dot) > DOT_THRESHOLD:
                # Degenerate or (nearly) collinear inputs: plain linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1

            if inputs_are_torch:
                v2 = torch.from_numpy(v2).to(device)

            return v2

        # blend
        blended_latents = slerp(self.alpha, latents_a, latents_b)

        # Move the result off the compute device and release cached memory.
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        blended_latents = blended_latents.to("cpu")
        torch.cuda.empty_cache()
        if device == torch.device("mps"):
            mps.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, blended_latents)
        return build_latents_output(latents_name=name, latents=blended_latents)
|