From 2e14528e4c2fffa9b23415f0a016bc05bd09090b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 27 Jun 2023 13:57:31 +1000 Subject: [PATCH] feat(nodes): default to CPU noise --- invokeai/app/invocations/latent.py | 77 +------------- invokeai/app/invocations/noise.py | 134 ++++++++++++++++++++++++ invokeai/app/services/default_graphs.py | 3 +- 3 files changed, 137 insertions(+), 77 deletions(-) create mode 100644 invokeai/app/invocations/noise.py diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 015ecab211..f42743eb62 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.util.devices import torch_dtype from ...backend.model_management.lora import ModelPatcher from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) @@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor): height=latents.size()[2] * 8, ) -class NoiseOutput(BaseInvocationOutput): - """Invocation noise output""" - #fmt: off - type: Literal["noise_output"] = "noise_output" - - # Inputs - noise: LatentsField = Field(default=None, description="The output noise") - width: int = Field(description="The width of the noise in pixels") - height: int = Field(description="The height of the noise in pixels") - #fmt: on - -def build_noise_output(latents_name: str, latents: torch.Tensor): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - SAMPLER_NAME_VALUES = Literal[ tuple(list(SCHEDULER_MAP.keys())) ] - def get_scheduler( context: InvocationContext, scheduler_info: ModelInfo, @@ -105,62 +86,6 @@ def get_scheduler( return scheduler -def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8): - # limit noise to only the diffusion image channels, not the mask channels - input_channels = min(latent_channels, 4) - use_device = "cpu" if (use_mps_noise or device.type == "mps") else device - generator = torch.Generator(device=use_device).manual_seed(seed) - x = torch.randn( - [ - 1, - input_channels, - height // downsampling_factor, - width // downsampling_factor, - ], - dtype=torch_dtype(device), - device=use_device, - generator=generator, - ).to(device) - # if self.perlin > 0.0: - # perlin_noise = self.get_perlin_noise( - # width // self.downsampling_factor, height // self.downsampling_factor - # ) - # x = (1 - self.perlin) * x + self.perlin * perlin_noise - return x - -class NoiseInvocation(BaseInvocation): - """Generates latent noise.""" - - type: Literal["noise"] = "noise" - - # Inputs - seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed) - width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", ) - height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", ) - - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["latents", "noise"], - }, - } - - @validator("seed", pre=True) - def modulo_seed(cls, v): - """Returns the seed modulo SEED_MAX to ensure it is within the valid range.""" - return v % SEED_MAX - - def invoke(self, context: InvocationContext) -> NoiseOutput: - device = torch.device(choose_torch_device()) - noise = get_noise(self.width, self.height, device, self.seed) - - name = f'{context.graph_execution_state_id}__{self.id}' - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise) - - # Text to image class TextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py new file mode 100644 index 0000000000..c5866f3608 --- /dev/null +++ b/invokeai/app/invocations/noise.py @@ -0,0 +1,134 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team + +import math +from typing import Literal + +from pydantic import Field, validator +import torch +from invokeai.app.invocations.latent import LatentsField + +from invokeai.app.util.misc import SEED_MAX, get_random_seed +from ...backend.util.devices import choose_torch_device, torch_dtype +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationConfig, + InvocationContext, +) + +""" +Utilities +""" + + +def get_noise( + width: int, + height: int, + device: torch.device, + seed: int = 0, + latent_channels: int = 4, + downsampling_factor: int = 8, + use_cpu: bool = True, + perlin: float = 0.0, +): + """Generate noise for a given image size.""" + noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type + + # limit noise to only the diffusion image channels, not the mask channels + input_channels = min(latent_channels, 4) + generator = torch.Generator(device=noise_device_type).manual_seed(seed) + + noise_tensor = torch.randn( + [ + 1, + input_channels, + height // downsampling_factor, + width // downsampling_factor, + ], + dtype=torch_dtype(device), + device=noise_device_type, + generator=generator, + ).to(device) + + return noise_tensor + + +""" +Nodes +""" + + +class NoiseOutput(BaseInvocationOutput): + """Invocation noise output""" + + # fmt: off + type: Literal["noise_output"] = "noise_output" + + # Inputs + noise: LatentsField = Field(default=None, description="The output noise") + width: int = Field(description="The width of the noise in pixels") + height: int = Field(description="The height of the noise in pixels") + # fmt: on + + +def build_noise_output(latents_name: str, latents: torch.Tensor): + return NoiseOutput( + noise=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + + +class NoiseInvocation(BaseInvocation): + """Generates latent noise.""" + + type: Literal["noise"] = "noise" + + # Inputs + seed: int = Field( + ge=0, + le=SEED_MAX, + description="The seed to use", + default_factory=get_random_seed, + ) + width: int = Field( + default=512, + multiple_of=8, + gt=0, + description="The width of the resulting noise", + ) + height: int = Field( + default=512, + multiple_of=8, + gt=0, + description="The height of the resulting noise", + ) + use_cpu: bool = Field( + default=True, + description="Use CPU for noise generation (for reproducible results across platforms)", + ) + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents", "noise"], + }, + } + + @validator("seed", pre=True) + def modulo_seed(cls, v): + """Returns the seed modulo SEED_MAX to ensure it is within the valid range.""" + return v % SEED_MAX + + def invoke(self, context: InvocationContext) -> NoiseOutput: + noise = get_noise( + width=self.width, + height=self.height, + device=choose_torch_device(), + seed=self.seed, + use_cpu=self.use_cpu, + ) + name = f"{context.graph_execution_state_id}__{self.id}" + context.services.latents.save(name, noise) + return build_noise_output(latents_name=name, latents=noise) diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py index 5eda5e957d..92263751b7 100644 --- a/invokeai/app/services/default_graphs.py +++ b/invokeai/app/services/default_graphs.py @@ -1,4 +1,5 @@ -from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation +from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation +from ..invocations.noise import NoiseInvocation from ..invocations.compel import CompelInvocation from ..invocations.params import ParamIntInvocation from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph