From 103e34691b66e420bb62e0e7fd635ad23ddd07fd Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 5 Jun 2024 11:05:44 -0400
Subject: [PATCH] Move ScaleLatentsInvocation and ResizeLatentsInvocation to
 their own file. No functional changes.

Also fix a copy-paste error while moving: in the new file, the height
field's description now uses FieldDescriptions.height instead of
FieldDescriptions.width.

---
 invokeai/app/invocations/latent.py         |  89 ------------------
 invokeai/app/invocations/resize_latents.py | 103 +++++++++++++++++++++
 2 files changed, 103 insertions(+), 89 deletions(-)
 create mode 100644 invokeai/app/invocations/resize_latents.py

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 5ab5969556..5946e66327 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1113,95 +1113,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         return ImageOutput.build(image_dto)
 
 
-LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
-
-
-@invocation(
-    "lresize",
-    title="Resize Latents",
-    tags=["latents", "resize"],
-    category="latents",
-    version="1.0.2",
-)
-class ResizeLatentsInvocation(BaseInvocation):
-    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
-
-    latents: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    width: int = InputField(
-        ge=64,
-        multiple_of=LATENT_SCALE_FACTOR,
-        description=FieldDescriptions.width,
-    )
-    height: int = InputField(
-        ge=64,
-        multiple_of=LATENT_SCALE_FACTOR,
-        description=FieldDescriptions.width,
-    )
-    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
-    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
-
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-        device = TorchDevice.choose_torch_device()
-
-        resized_latents = torch.nn.functional.interpolate(
-            latents.to(device),
-            size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        resized_latents = resized_latents.to("cpu")
-
-        TorchDevice.empty_cache()
-
-        name = context.tensors.save(tensor=resized_latents)
-        return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
-
-
-@invocation(
-    "lscale",
-    title="Scale Latents",
-    tags=["latents", "resize"],
-    category="latents",
-    version="1.0.2",
-)
-class ScaleLatentsInvocation(BaseInvocation):
-    """Scales latents by a given factor."""
-
-    latents: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
-    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
-    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
-
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-
-        device = TorchDevice.choose_torch_device()
-
-        # resizing
-        resized_latents = torch.nn.functional.interpolate(
-            latents.to(device),
-            scale_factor=self.scale_factor,
-            mode=self.mode,
-            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
-        )
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        resized_latents = resized_latents.to("cpu")
-        TorchDevice.empty_cache()
-
-        name = context.tensors.save(tensor=resized_latents)
-        return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
-
-
 @invocation(
     "i2l",
     title="Image to Latents",
diff --git a/invokeai/app/invocations/resize_latents.py b/invokeai/app/invocations/resize_latents.py
new file mode 100644
index 0000000000..90253e52e8
--- /dev/null
+++ b/invokeai/app/invocations/resize_latents.py
@@ -0,0 +1,103 @@
+from typing import Literal
+
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+)
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
+
+
+@invocation(
+    "lresize",
+    title="Resize Latents",
+    tags=["latents", "resize"],
+    category="latents",
+    version="1.0.2",
+)
+class ResizeLatentsInvocation(BaseInvocation):
+    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    width: int = InputField(
+        ge=64,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description=FieldDescriptions.width,
+    )
+    height: int = InputField(
+        ge=64,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description=FieldDescriptions.height,
+    )
+    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+        device = TorchDevice.choose_torch_device()
+
+        resized_latents = torch.nn.functional.interpolate(
+            latents.to(device),
+            size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
+
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=resized_latents)
+        return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
+
+
+@invocation(
+    "lscale",
+    title="Scale Latents",
+    tags=["latents", "resize"],
+    category="latents",
+    version="1.0.2",
+)
+class ScaleLatentsInvocation(BaseInvocation):
+    """Scales latents by a given factor."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
+    mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
+    antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        device = TorchDevice.choose_torch_device()
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents.to(device),
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        resized_latents = resized_latents.to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=resized_latents)
+        return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
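-- 

For readers outside the InvokeAI codebase: both invocations reduce to a single
torch.nn.functional.interpolate call, differing only in whether they pass
size=... (ResizeLatentsInvocation) or scale_factor=... (ScaleLatentsInvocation).
The sketch below is a minimal standalone illustration of that shared logic,
not InvokeAI code: the resize_latents helper, the demo tensor shapes, and the
hard-coded LATENT_SCALE_FACTOR = 8 are assumptions made for the example
(InvokeAI imports the real constant from invokeai.app.invocations.constants).

from typing import Literal

import torch

LATENT_SCALE_FACTOR = 8  # one latent cell covers an 8x8 block of pixels (assumed here)

LatentsInterpolationMode = Literal[
    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]


def resize_latents(
    latents: torch.Tensor,
    width: int,
    height: int,
    mode: LatentsInterpolationMode = "bilinear",
    antialias: bool = False,
) -> torch.Tensor:
    """Resize a (batch, channels, h, w) latent tensor to the given pixel size."""
    resized = torch.nn.functional.interpolate(
        latents,
        # Pixel dimensions are floor-divided by 8, as in ResizeLatentsInvocation.
        size=(height // LATENT_SCALE_FACTOR, width // LATENT_SCALE_FACTOR),
        mode=mode,
        # torch supports antialiasing only for bilinear/bicubic, so the
        # invocations force it off for every other mode.
        antialias=antialias if mode in ["bilinear", "bicubic"] else False,
    )
    # Return on the CPU so the result does not pin GPU memory between pipeline
    # stages (the reason for the huggingface.co link in the patch).
    return resized.to("cpu")


if __name__ == "__main__":
    latents = torch.randn(1, 4, 64, 64)  # latents for a 512x512-pixel image
    out = resize_latents(latents, width=768, height=512, antialias=True)
    print(out.shape)  # torch.Size([1, 4, 64, 96])

ScaleLatentsInvocation would make the same call with
scale_factor=self.scale_factor in place of the size=... argument, which is why
the two invocations share a file after this patch.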