mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update l2i invoke and seamless to support AutoencoderTiny, remove attention processors if no mid_block is detected (#5936)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Have you updated all relevant documentation?

- [ ] Yes
- [x] No

## Description

L2i throws an assertion error when run with `madebyollin/taesdxl`, because that VAE requires a different class in diffusers to load it. This is a small PR that updates seamless and l2i to accept AutoencoderTiny models without throwing exceptions while processing them.

## QA Instructions, Screenshots, Recordings

<img width="445" alt="Screenshot 2024-03-12 at 12 04 29 PM" src="https://github.com/invoke-ai/InvokeAI/assets/58442074/34a17e44-d911-4fef-8fc1-71f7b688688c">

Run an SDXL pipeline using a VAE that requires AutoencoderTiny and validate that the image successfully encodes and decodes.

## Merge Plan

This PR can be merged when approved.
This commit is contained in:
commit
54f1a1f952
@ -837,14 +837,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, torch.nn.Module)
|
assert isinstance(vae, torch.nn.Module)
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
if self.fp32:
|
if self.fp32:
|
||||||
vae.to(dtype=torch.float32)
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
vae.decoder.mid_block.attentions[0].processor,
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
@ -1018,7 +1018,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
if upcast:
|
if upcast:
|
||||||
vae.to(dtype=torch.float32)
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
vae.decoder.mid_block.attentions[0].processor,
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
|
@ -5,6 +5,7 @@ from typing import Callable, List, Union
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||||
|
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||||
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||||
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user