Make FLUX get_noise(...) consistent across devices/dtypes.

This commit is contained in:
Ryan Dick 2024-08-22 15:56:30 +00:00 committed by Brandon
parent 0c5649491e
commit 185f2a395f
2 changed files with 7 additions and 6 deletions

View File

@ -79,8 +79,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
inference_dtype = torch.bfloat16 inference_dtype = torch.bfloat16
# Prepare input noise. # Prepare input noise.
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
# CPU RNG?
x = get_noise( x = get_noise(
num_samples=1, num_samples=1,
height=self.height, height=self.height,

View File

@ -20,16 +20,19 @@ def get_noise(
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
): ):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn( return torch.randn(
num_samples, num_samples,
16, 16,
# allow for packing # allow for packing
2 * math.ceil(height / 16), 2 * math.ceil(height / 16),
2 * math.ceil(width / 16), 2 * math.ceil(width / 16),
device=device, device=rand_device,
dtype=dtype, dtype=rand_dtype,
generator=torch.Generator(device=device).manual_seed(seed), generator=torch.Generator(device=rand_device).manual_seed(seed),
) ).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: