mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 045caddee1 (parent 58697141bf)

    Move LatentsToImageInvocation to its own file. No functional changes.
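The hunks below remove LatentsToImageInvocation and its now-unused imports from the module that also defines DenoiseLatentsInvocation, and the new file re-creates the class unchanged. For callers, only the import path changes. A minimal sketch of the caller-side update; the old module path is inferred from the new file's DEFAULT_PRECISION import and is not shown explicitly in this diff:

# hypothetical caller update (old path assumed, not confirmed by the diff)
# before: from invokeai.app.invocations.latent import LatentsToImageInvocation
# after:
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation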
@@ -8,16 +8,7 @@ import torch
 import torchvision
 import torchvision.transforms as T
 from diffusers.configuration_utils import ConfigMixin
-from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
 from diffusers.schedulers.scheduling_tcd import TCDScheduler
@@ -38,11 +29,9 @@ from invokeai.app.invocations.fields import (
     LatentsField,
     OutputField,
     UIType,
-    WithBoard,
-    WithMetadata,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
+from invokeai.app.invocations.primitives import DenoiseMaskOutput, LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -1033,83 +1022,3 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
-
-
-@invocation(
-    "l2i",
-    title="Latents to Image",
-    tags=["latents", "image", "vae", "l2i"],
-    category="latents",
-    version="1.2.2",
-)
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
-    """Generates an image from latents."""
-
-    latents: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    vae: VAEField = InputField(
-        description=FieldDescriptions.vae,
-        input=Input.Connection,
-    )
-    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-
-        vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.nn.Module)
-            latents = latents.to(vae.device)
-            if self.fp32:
-                vae.to(dtype=torch.float32)
-
-                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                    vae.decoder.mid_block.attentions[0].processor,
-                    (
-                        AttnProcessor2_0,
-                        XFormersAttnProcessor,
-                        LoRAXFormersAttnProcessor,
-                        LoRAAttnProcessor2_0,
-                    ),
-                )
-                # if xformers or torch_2_0 is used attention block does not need
-                # to be in float32 which can save lots of memory
-                if use_torch_2_0_or_xformers:
-                    vae.post_quant_conv.to(latents.dtype)
-                    vae.decoder.conv_in.to(latents.dtype)
-                    vae.decoder.mid_block.to(latents.dtype)
-                else:
-                    latents = latents.float()
-
-            else:
-                vae.to(dtype=torch.float16)
-                latents = latents.half()
-
-            if self.tiled or context.config.get().force_tiled_decode:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
-
-            # clear memory as vae decode can request a lot
-            TorchDevice.empty_cache()
-
-            with torch.inference_mode():
-                # copied from diffusers pipeline
-                latents = latents / vae.config.scaling_factor
-                image = vae.decode(latents, return_dict=False)[0]
-                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
-                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-                image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-
-                TorchDevice.empty_cache()
-
-                image_dto = context.images.save(image=image)
-
-                return ImageOutput.build(image_dto)
invokeai/app/invocations/latents_to_image.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+import torch
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.latent import DEFAULT_PRECISION
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "l2i",
+    title="Latents to Image",
+    tags=["latents", "image", "vae", "l2i"],
+    category="latents",
+    version="1.2.2",
+)
+class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
+        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
+            latents = latents.to(vae.device)
+            if self.fp32:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # if xformers or torch_2_0 is used attention block does not need
+                # to be in float32 which can save lots of memory
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(latents.dtype)
+                    vae.decoder.conv_in.to(latents.dtype)
+                    vae.decoder.mid_block.to(latents.dtype)
+                else:
+                    latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                latents = latents.half()
+
+            if self.tiled or context.config.get().force_tiled_decode:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
+            # clear memory as vae decode can request a lot
+            TorchDevice.empty_cache()
+
+            with torch.inference_mode():
+                # copied from diffusers pipeline
+                latents = latents / vae.config.scaling_factor
+                image = vae.decode(latents, return_dict=False)[0]
+                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
+                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+                image = VaeImageProcessor.numpy_to_pil(np_image)[0]
+
+                TorchDevice.empty_cache()
+
+                image_dto = context.images.save(image=image)
+
+                return ImageOutput.build(image_dto)
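For reference, a stripped-down sketch of the decode path this invocation implements, run outside the InvokeAI invocation framework. The VAE checkpoint name, latent shape, and output filename are illustrative assumptions and are not part of this commit:

# Minimal standalone VAE decode, mirroring the body of LatentsToImageInvocation.invoke().
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint, for illustration
latents = torch.randn(1, 4, 64, 64)  # placeholder latents (SD-1.x shape); normally produced by denoising

with torch.inference_mode():
    # Undo the scaling applied when the latents were created.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents, return_dict=False)[0]
    # Map from [-1, 1] back to [0, 1] before converting to a PIL image.
    image = (image / 2 + 0.5).clamp(0, 1)
    np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()

pil_image = VaeImageProcessor.numpy_to_pil(np_image)[0]
pil_image.save("decoded.png")  # illustrative output path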