mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add VAE
This commit is contained in:
parent
0d2e194213
commit
ea40a7844a
@ -33,7 +33,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.seamless import set_unet_seamless
|
||||
from ...backend.model_management.seamless import set_unet_seamless, set_vae_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
@ -491,7 +491,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
context=context,
|
||||
)
|
||||
|
||||
with vae_info as vae:
|
||||
with set_vae_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
@ -46,6 +46,7 @@ class ClipField(BaseModel):
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description="Axes(\"x\" and \"y\") to which apply seamless")
|
||||
|
||||
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
@ -398,6 +399,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
|
||||
# Outputs
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
@title("Seamless")
|
||||
@tags("seamless", "model")
|
||||
@ -410,7 +412,9 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
)
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
||||
@ -418,6 +422,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||
unet = copy.deepcopy(self.unet)
|
||||
vae = copy.deepcopy(self.vae)
|
||||
|
||||
seamless_axes_list = []
|
||||
|
||||
@ -427,7 +432,9 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
seamless_axes_list.append('y')
|
||||
|
||||
unet.seamless_axes = seamless_axes_list
|
||||
|
||||
vae.seamless_axes = seamless_axes_list
|
||||
|
||||
return SeamlessModeOutput(
|
||||
unet=unet,
|
||||
vae=vae
|
||||
)
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch.nn as nn
|
||||
from diffusers.models import UNet2DModel
|
||||
from diffusers.models import UNet2DModel, AutoencoderKL
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
@ -51,6 +51,42 @@ def set_unet_seamless(model: UNet2DModel, seamless_axes):
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for module, orig_conv_forward in to_restore:
|
||||
module._conv_forward = orig_conv_forward
|
||||
if hasattr(m, "asymmetric_padding_mode"):
|
||||
del m.asymmetric_padding_mode
|
||||
if hasattr(m, "asymmetric_padding"):
|
||||
del m.asymmetric_padding
|
||||
|
||||
def set_vae_seamless(model: AutoencoderKL, seamless_axes):
|
||||
try:
|
||||
to_restore = []
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["x"] = (
|
||||
m._reversed_padding_repeated_twice[0],
|
||||
m._reversed_padding_repeated_twice[1],
|
||||
0,
|
||||
0,
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["y"] = (
|
||||
0,
|
||||
0,
|
||||
m._reversed_padding_repeated_twice[2],
|
||||
m._reversed_padding_repeated_twice[3],
|
||||
)
|
||||
|
||||
to_restore.append((m, m._conv_forward))
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for module, orig_conv_forward in to_restore:
|
||||
module._conv_forward = orig_conv_forward
|
||||
|
Loading…
Reference in New Issue
Block a user