Fix issue with seamless context managers when seamless is not configured.

This commit is contained in:
Ryan Dick
2024-01-05 10:31:58 -05:00
parent 7f3be627c2
commit 9b763b9e4c

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import contextlib
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
@ -716,10 +717,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.model_dump(), **self.unet.unet.model_dump(),
context=context, context=context,
) )
# Prepare seamless context, if configured.
seamless_context = contextlib.nullcontext()
seamless_config = self.unet.seamless
if seamless_config is not None:
seamless_context = set_seamless(
model=unet_info.context.model,
axes=seamless_config.axes,
skipped_layers=seamless_config.skipped_layers,
skip_second_resnet=seamless_config.skip_second_resnet,
skip_conv2=seamless_config.skip_conv2,
)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, **self.unet.seamless.dict()), seamless_context,
unet_info as unet, unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(unet, _lora_loader()),
@ -826,7 +840,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
context=context, context=context,
) )
with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae: # Prepare seamless context, if configured.
seamless_context = contextlib.nullcontext()
seamless_config = self.vae.seamless
if seamless_config is not None:
seamless_context = set_seamless(
model=vae_info.context.model,
axes=seamless_config.axes,
skipped_layers=seamless_config.skipped_layers,
skip_second_resnet=seamless_config.skip_second_resnet,
skip_conv2=seamless_config.skip_conv2,
)
with seamless_context, vae_info as vae:
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)