This commit is contained in:
JPPhoto 2023-03-13 08:11:09 -05:00
parent 7fe2606cb3
commit b980e563b9
6 changed files with 22 additions and 20 deletions

View File

@ -347,7 +347,6 @@ class Generator:
h_symmetry_time_pct=h_symmetry_time_pct, h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct, v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback=attention_maps_callback, attention_maps_callback=attention_maps_callback,
seed=seed,
**kwargs, **kwargs,
) )
results = [] results = []
@ -375,7 +374,8 @@ class Generator:
print("** An error occurred while getting initial noise **") print("** An error occurred while getting initial noise **")
print(traceback.format_exc()) print(traceback.format_exc())
image = make_image(x_T) # Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None: if self.safety_checker is not None:
image = self.safety_checker.check(image) image = self.safety_checker.check(image)

View File

@ -37,7 +37,6 @@ class Img2Img(Generator):
h_symmetry_time_pct=None, h_symmetry_time_pct=None,
v_symmetry_time_pct=None, v_symmetry_time_pct=None,
attention_maps_callback=None, attention_maps_callback=None,
seed=None,
**kwargs, **kwargs,
): ):
""" """
@ -64,7 +63,7 @@ class Img2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise # FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if # We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match. # necessary, which the x_T input might not match.
@ -77,7 +76,7 @@ class Img2Img(Generator):
conditioning_data, conditioning_data,
noise_func=self.get_noise_like, noise_func=self.get_noise_like,
callback=step_callback, callback=step_callback,
seed=seed seed=seed,
) )
if ( if (
pipeline_output.attention_map_saver is not None pipeline_output.attention_map_saver is not None
@ -88,9 +87,7 @@ class Img2Img(Generator):
return make_image return make_image
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]): def get_noise_like(self, like: torch.Tensor):
if seed is not None:
set_seed(seed)
device = like.device device = like.device
if device.type == "mps": if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device) x = torch.randn_like(like, device="cpu").to(device)

View File

@ -311,7 +311,7 @@ class Inpaint(Img2Img):
uc, c, cfg_scale uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings( pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image, init_image=init_image,
mask=1 - mask, # expects white means "paint here." mask=1 - mask, # expects white means "paint here."
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
noise_func=self.get_noise_like, noise_func=self.get_noise_like,
callback=step_callback, callback=step_callback,
seed=seed seed=seed,
) )
if ( if (

View File

@ -61,7 +61,7 @@ class Txt2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T) -> PIL.Image.Image: def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings( pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T, noise=x_T,

View File

@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings( first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T), latents=torch.zeros_like(x_T),
num_inference_steps=steps, num_inference_steps=steps,

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import einops import einops
import PIL.Image import PIL.Image
from accelerate.utils import set_seed
import psutil import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet), device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype, dtype=self.unet.dtype,
) )
noise = noise_func(initial_latents, seed) if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings( return self.img2img_from_latents_and_embeddings(
initial_latents, initial_latents,
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image_latents = self.non_noised_latents_from_image( init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype init_image, device=device, dtype=latents_dtype
) )
noise = noise_func(init_image_latents, seed) if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)
if mask.dim() == 3: if mask.dim() == 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)