mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): default to CPU noise
parent 3c30368c62
commit 2e14528e4c
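In short, this commit makes latent-noise generation run on the CPU by default, so a given seed produces the same starting latents whether the pipeline later runs on CUDA, MPS, or CPU. A minimal sketch of the idea (not the project's code; the helper name and shapes are illustrative):

    import torch

    def reproducible_noise(seed: int, shape, device: torch.device) -> torch.Tensor:
        # Seed a CPU generator so the random stream does not depend on the GPU/MPS backend.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        # Sample on the CPU, then move the finished tensor to the execution device.
        return torch.randn(shape, generator=generator, device="cpu").to(device)

    noise = reproducible_noise(seed=42, shape=[1, 4, 64, 64], device=torch.device("cpu"))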
@@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
     PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import torch_dtype
 from ...backend.model_management.lora import ModelPatcher
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
@@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
         height=latents.size()[2] * 8,
     )

-class NoiseOutput(BaseInvocationOutput):
-    """Invocation noise output"""
-    #fmt: off
-    type: Literal["noise_output"] = "noise_output"
-
-    # Inputs
-    noise: LatentsField = Field(default=None, description="The output noise")
-    width: int = Field(description="The width of the noise in pixels")
-    height: int = Field(description="The height of the noise in pixels")
-    #fmt: on
-
-def build_noise_output(latents_name: str, latents: torch.Tensor):
-    return NoiseOutput(
-        noise=LatentsField(latents_name=latents_name),
-        width=latents.size()[3] * 8,
-        height=latents.size()[2] * 8,
-    )
-
-
 SAMPLER_NAME_VALUES = Literal[
     tuple(list(SCHEDULER_MAP.keys()))
 ]



 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
@@ -105,62 +86,6 @@ def get_scheduler(
     return scheduler


-def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
-    # limit noise to only the diffusion image channels, not the mask channels
-    input_channels = min(latent_channels, 4)
-    use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
-    generator = torch.Generator(device=use_device).manual_seed(seed)
-    x = torch.randn(
-        [
-            1,
-            input_channels,
-            height // downsampling_factor,
-            width // downsampling_factor,
-        ],
-        dtype=torch_dtype(device),
-        device=use_device,
-        generator=generator,
-    ).to(device)
-    # if self.perlin > 0.0:
-    #     perlin_noise = self.get_perlin_noise(
-    #         width // self.downsampling_factor, height // self.downsampling_factor
-    #     )
-    #     x = (1 - self.perlin) * x + self.perlin * perlin_noise
-    return x
-
-class NoiseInvocation(BaseInvocation):
-    """Generates latent noise."""
-
-    type: Literal["noise"] = "noise"
-
-    # Inputs
-    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
-
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["latents", "noise"],
-            },
-        }
-
-    @validator("seed", pre=True)
-    def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
-
-    def invoke(self, context: InvocationContext) -> NoiseOutput:
-        device = torch.device(choose_torch_device())
-        noise = get_noise(self.width, self.height, device, self.seed)
-
-        name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, noise)
-        return build_noise_output(latents_name=name, latents=noise)
-
-
 # Text to image
 class TextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""
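The NoiseOutput, build_noise_output, get_noise, and NoiseInvocation definitions removed above (from what appears to be invokeai/app/invocations/latent.py, given the import changes later in this commit) reappear in the new invokeai/app/invocations/noise.py below, with the use_mps_noise flag replaced by a use_cpu flag that defaults to True. A rough sketch of how the device-selection rule changes (function names are illustrative, not from the codebase):

    import torch

    def old_noise_device(device: torch.device, use_mps_noise: bool = False):
        # Before: noise fell back to the CPU only for MPS, or when explicitly requested.
        return "cpu" if (use_mps_noise or device.type == "mps") else device

    def new_noise_device(device: torch.device, use_cpu: bool = True) -> str:
        # After: the CPU generator is used whenever use_cpu is set, which is now the default.
        return "cpu" if (use_cpu or device.type == "mps") else device.type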

invokeai/app/invocations/noise.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
+
+import math
+from typing import Literal
+
+from pydantic import Field, validator
+import torch
+from invokeai.app.invocations.latent import LatentsField
+
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationConfig,
+    InvocationContext,
+)
+
+"""
+Utilities
+"""
+
+
+def get_noise(
+    width: int,
+    height: int,
+    device: torch.device,
+    seed: int = 0,
+    latent_channels: int = 4,
+    downsampling_factor: int = 8,
+    use_cpu: bool = True,
+    perlin: float = 0.0,
+):
+    """Generate noise for a given image size."""
+    noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
+
+    # limit noise to only the diffusion image channels, not the mask channels
+    input_channels = min(latent_channels, 4)
+    generator = torch.Generator(device=noise_device_type).manual_seed(seed)
+
+    noise_tensor = torch.randn(
+        [
+            1,
+            input_channels,
+            height // downsampling_factor,
+            width // downsampling_factor,
+        ],
+        dtype=torch_dtype(device),
+        device=noise_device_type,
+        generator=generator,
+    ).to(device)
+
+    return noise_tensor
+
+
+"""
+Nodes
+"""
+
+
+class NoiseOutput(BaseInvocationOutput):
+    """Invocation noise output"""
+
+    # fmt: off
+    type: Literal["noise_output"] = "noise_output"
+
+    # Inputs
+    noise: LatentsField = Field(default=None, description="The output noise")
+    width: int = Field(description="The width of the noise in pixels")
+    height: int = Field(description="The height of the noise in pixels")
+    # fmt: on
+
+
+def build_noise_output(latents_name: str, latents: torch.Tensor):
+    return NoiseOutput(
+        noise=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )
+
+
+class NoiseInvocation(BaseInvocation):
+    """Generates latent noise."""
+
+    type: Literal["noise"] = "noise"
+
+    # Inputs
+    seed: int = Field(
+        ge=0,
+        le=SEED_MAX,
+        description="The seed to use",
+        default_factory=get_random_seed,
+    )
+    width: int = Field(
+        default=512,
+        multiple_of=8,
+        gt=0,
+        description="The width of the resulting noise",
+    )
+    height: int = Field(
+        default=512,
+        multiple_of=8,
+        gt=0,
+        description="The height of the resulting noise",
+    )
+    use_cpu: bool = Field(
+        default=True,
+        description="Use CPU for noise generation (for reproducible results across platforms)",
+    )
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "noise"],
+            },
+        }
+
+    @validator("seed", pre=True)
+    def modulo_seed(cls, v):
+        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
+        return v % SEED_MAX
+
+    def invoke(self, context: InvocationContext) -> NoiseOutput:
+        noise = get_noise(
+            width=self.width,
+            height=self.height,
+            device=choose_torch_device(),
+            seed=self.seed,
+            use_cpu=self.use_cpu,
+        )
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.save(name, noise)
+        return build_noise_output(latents_name=name, latents=noise)
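Assuming the new module is importable under the path shown in the file header above, a caller might use the relocated helper roughly like this (the seed and sizes are illustrative):

    import torch
    from invokeai.app.invocations.noise import get_noise

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # With use_cpu=True (the new default), noise is sampled on the CPU and then moved to `device`.
    noise = get_noise(width=512, height=512, device=device, seed=123, use_cpu=True)
    print(noise.shape, noise.device)  # expected: torch.Size([1, 4, 64, 64]) on the chosen device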

@@ -1,4 +1,5 @@
-from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
+from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
+from ..invocations.noise import NoiseInvocation
 from ..invocations.compel import CompelInvocation
 from ..invocations.params import ParamIntInvocation
 from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph