This commit is contained in:
JPPhoto 2023-03-13 08:11:09 -05:00
parent 7fe2606cb3
commit b980e563b9
6 changed files with 22 additions and 20 deletions

View File

@ -58,7 +58,7 @@ class InvokeAIGeneratorOutput:
''' '''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes) generating the image (in .params, also available as attributes)
''' '''
image: Image image: Image
@ -116,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
outputs = txt2img.generate(prompt='banana sushi', iterations=None) outputs = txt2img.generate(prompt='banana sushi', iterations=None)
for o in outputs: for o in outputs:
print(o.image, o.seed) print(o.image, o.seed)
''' '''
generator_args = dataclasses.asdict(self.params) generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args) generator_args.update(keyword_args)
@ -167,7 +167,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
if callback: if callback:
callback(output) callback(output)
yield output yield output
@classmethod @classmethod
def schedulers(self)->List[str]: def schedulers(self)->List[str]:
''' '''
@ -177,7 +177,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision) return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config) scheduler = scheduler_class.from_config(model.scheduler.config)
@ -267,12 +267,12 @@ class Embiggen(Txt2Img):
embiggen_tiles=embiggen_tiles, embiggen_tiles=embiggen_tiles,
strength=strength, strength=strength,
**kwargs) **kwargs)
@classmethod @classmethod
def _generator_class(cls): def _generator_class(cls):
from .embiggen import Embiggen from .embiggen import Embiggen
return Embiggen return Embiggen
class Generator: class Generator:
downsampling_factor: int downsampling_factor: int
@ -347,7 +347,6 @@ class Generator:
h_symmetry_time_pct=h_symmetry_time_pct, h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct, v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback=attention_maps_callback, attention_maps_callback=attention_maps_callback,
seed=seed,
**kwargs, **kwargs,
) )
results = [] results = []
@ -375,7 +374,8 @@ class Generator:
print("** An error occurred while getting initial noise **") print("** An error occurred while getting initial noise **")
print(traceback.format_exc()) print(traceback.format_exc())
image = make_image(x_T) # Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None: if self.safety_checker is not None:
image = self.safety_checker.check(image) image = self.safety_checker.check(image)

View File

@ -37,7 +37,6 @@ class Img2Img(Generator):
h_symmetry_time_pct=None, h_symmetry_time_pct=None,
v_symmetry_time_pct=None, v_symmetry_time_pct=None,
attention_maps_callback=None, attention_maps_callback=None,
seed=None,
**kwargs, **kwargs,
): ):
""" """
@ -64,7 +63,7 @@ class Img2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise # FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if # We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match. # necessary, which the x_T input might not match.
@ -77,7 +76,7 @@ class Img2Img(Generator):
conditioning_data, conditioning_data,
noise_func=self.get_noise_like, noise_func=self.get_noise_like,
callback=step_callback, callback=step_callback,
seed=seed seed=seed,
) )
if ( if (
pipeline_output.attention_map_saver is not None pipeline_output.attention_map_saver is not None
@ -88,9 +87,7 @@ class Img2Img(Generator):
return make_image return make_image
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]): def get_noise_like(self, like: torch.Tensor):
if seed is not None:
set_seed(seed)
device = like.device device = like.device
if device.type == "mps": if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device) x = torch.randn_like(like, device="cpu").to(device)

View File

@ -311,7 +311,7 @@ class Inpaint(Img2Img):
uc, c, cfg_scale uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings( pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image, init_image=init_image,
mask=1 - mask, # expects white means "paint here." mask=1 - mask, # expects white means "paint here."
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
noise_func=self.get_noise_like, noise_func=self.get_noise_like,
callback=step_callback, callback=step_callback,
seed=seed seed=seed,
) )
if ( if (

View File

@ -61,7 +61,7 @@ class Txt2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T) -> PIL.Image.Image: def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings( pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T, noise=x_T,

View File

@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T): def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings( first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T), latents=torch.zeros_like(x_T),
num_inference_steps=steps, num_inference_steps=steps,

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import einops import einops
import PIL.Image import PIL.Image
from accelerate.utils import set_seed
import psutil import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet), device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype, dtype=self.unet.dtype,
) )
noise = noise_func(initial_latents, seed) if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings( return self.img2img_from_latents_and_embeddings(
initial_latents, initial_latents,
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image_latents = self.non_noised_latents_from_image( init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype init_image, device=device, dtype=latents_dtype
) )
noise = noise_func(init_image_latents, seed) if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)
if mask.dim() == 3: if mask.dim() == 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)