feat(nodes): default to CPU noise

This commit is contained in:
psychedelicious 2023-06-27 13:57:31 +10:00
parent 3c30368c62
commit 2e14528e4c
3 changed files with 137 additions and 77 deletions

View File

@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
height=latents.size()[2] * 8,
)
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
@ -105,62 +86,6 @@ def get_scheduler(
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""

View File

@ -0,0 +1,134 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
import math
from typing import Literal
from pydantic import Field, validator
import torch
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
"""
Utilities
"""
def get_noise(
width: int,
height: int,
device: torch.device,
seed: int = 0,
latent_channels: int = 4,
downsampling_factor: int = 8,
use_cpu: bool = True,
perlin: float = 0.0,
):
"""Generate noise for a given image size."""
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
generator = torch.Generator(device=noise_device_type).manual_seed(seed)
noise_tensor = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=noise_device_type,
generator=generator,
).to(device)
return noise_tensor
"""
Nodes
"""
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
# fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
# fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(
ge=0,
le=SEED_MAX,
description="The seed to use",
default_factory=get_random_seed,
)
width: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting noise",
)
height: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting noise",
)
use_cpu: bool = Field(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph