diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ab59b41865..26294ed7f7 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1166,3 +1166,56 @@ class BlendLatentsInvocation(BaseInvocation): # context.services.latents.set(name, resized_latents) context.services.latents.save(name, blended_latents) return build_latents_output(latents_name=name, latents=blended_latents) + + +@invocation( + "lcrop", + title="Crop Latents", + tags=["latents", "crop"], + category="latents", + version="1.0.0", +) +class CropLatentsInvocation(BaseInvocation): + """Crops latents""" + + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + width: int = InputField( + ge=64, + multiple_of=_downsampling_factor, + description=FieldDescriptions.width, + ) + height: int = InputField( + ge=64, + multiple_of=_downsampling_factor, + description=FieldDescriptions.width, + ) + x_offset: int = InputField( + ge=0, + multiple_of=_downsampling_factor, + description="x-coordinate", + ) + y_offset: int = InputField( + ge=0, + multiple_of=_downsampling_factor, + description="y-coordinate", + ) + + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = context.services.latents.get(self.latents.latents_name) + + x1 = self.x_offset // _downsampling_factor + y1 = self.y_offset // _downsampling_factor + x2 = x1 + (self.width // _downsampling_factor) + y2 = y1 + (self.height // _downsampling_factor) + + cropped_latents = latents[:, :, y1:y2, x1:x2] + + # resized_latents = resized_latents.to("cpu") + + name = f"{context.graph_execution_state_id}__{self.id}" + context.services.latents.save(name, cropped_latents) + + return build_latents_output(latents_name=name, latents=cropped_latents)