feat(nodes): add resize and scale latents nodes

- this resize/scale latents functionality is what is needed for the hires fix
- also removes the unused `seed` field from t2l (`TextToLatentsInvocation`)
parent 00a0cb3403
commit 3e80eaa342
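For context, here is a minimal sketch (plain PyTorch, outside the InvokeAI graph) of the hires-fix pattern these nodes enable: denoise at a low resolution, upscale the latents, then denoise again at the higher resolution. The `denoise` function below is a hypothetical stand-in for a t2l pass, and the shapes assume Stable Diffusion's 8x VAE downsampling.

import torch

def denoise(latents: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a t2l (TextToLatentsInvocation) pass;
    # a real pass would run the UNet over the latents for N steps.
    return latents

# SD latents are [batch, 4 channels, height/8, width/8]
low_res = denoise(torch.randn(1, 4, 512 // 8, 512 // 8))

# What ScaleLatentsInvocation does with scale=2, mode="bilinear":
_, _, h, w = low_res.size()
upscaled = torch.nn.functional.interpolate(
    low_res, size=(h * 2, w * 2), mode="bilinear"
)

high_res = denoise(upscaled)  # second pass refines the upscaled latents
print(high_res.shape)  # torch.Size([1, 4, 128, 128]) -> decodes to a 1024x1024 image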
@@ -146,7 +146,6 @@ class TextToLatentsInvocation(BaseInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
@@ -363,9 +362,87 @@ class LatentsToImageInvocation(BaseInvocation):
                 session_id=context.graph_execution_state_id, node=self
             )
 
+            torch.cuda.empty_cache()
+
             context.services.images.save(image_type, image_name, image, metadata)
             return build_image_output(
-                image_type=image_type,
-                image_name=image_name,
-                image=image
+                image_type=image_type, image_name=image_name, image=image
             )
+
+
+LATENTS_INTERPOLATION_MODE = Literal[
+    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
+]
+
+
+class ResizeLatentsInvocation(BaseInvocation):
+    """Resizes latents to explicit width/height."""
+
+    type: Literal["lresize"] = "lresize"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
+    downsample: int = Field(
+        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                self.height // self.downsample,
+                self.width // self.downsample,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
+
+
+class ScaleLatentsInvocation(BaseInvocation):
+    """Scales latents by a given factor."""
+
+    type: Literal["lscale"] = "lscale"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale: int = Field(
+        default=2, ge=1, description="The factor by which to scale the latents"
+    )
+    mode: LATENTS_INTERPOLATION_MODE = Field(
+        default="bilinear", description="The interpolation mode"
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        (_, _, h, w) = latents.size()
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(
+                h * self.scale,
+                w * self.scale,
+            ),
+            mode=self.mode,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
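A quick shape check (with assumed example values) of the arithmetic in ResizeLatentsInvocation.invoke: width and height are given in image pixels, and the node divides them by `downsample` (8 for Stable Diffusion) to get the latent-space target size.

import torch

latents = torch.randn(1, 4, 64, 64)  # latents for a 512x512 image
width, height, downsample = 1024, 768, 8  # example target: 1024x768

# mirrors the interpolate call in ResizeLatentsInvocation.invoke
resized = torch.nn.functional.interpolate(
    latents, size=(height // downsample, width // downsample), mode="bilinear"
)
print(resized.shape)  # torch.Size([1, 4, 96, 128])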