updates per stalkers comments

This commit is contained in:
Kent Keirsey 2023-08-27 22:54:53 -04:00
parent 19e0f360e7
commit 5fdd25501b
3 changed files with 8 additions and 41 deletions

View File

@ -33,7 +33,7 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.models import BaseModelType from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_unet_seamless, set_vae_seamless from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ConditioningData,
@ -401,7 +401,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader() unet_info.context.model, _lora_loader()
), set_unet_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet: ), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
latents = latents.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None: if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -491,7 +491,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context, context=context,
) )
with set_vae_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)

View File

@ -398,8 +398,8 @@ class SeamlessModeOutput(BaseInvocationOutput):
type: Literal["seamless_output"] = "seamless_output" type: Literal["seamless_output"] = "seamless_output"
# Outputs # Outputs
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("Seamless") @title("Seamless")
@tags("seamless", "model") @tags("seamless", "model")

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import TypeVar, Union
import torch.nn as nn import torch.nn as nn
from diffusers.models import UNet2DModel, AutoencoderKL from diffusers.models import UNet2DModel, AutoencoderKL
@ -23,43 +23,10 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_unet_seamless(model: UNet2DModel, seamless_axes):
try: ModelType = TypeVar('ModelType', UNet2DModel, AutoencoderKL)
to_restore = []
def set_seamless(model: ModelType, seamless_axes):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding
def set_vae_seamless(model: AutoencoderKL, seamless_axes):
try: try:
to_restore = [] to_restore = []