mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix bug #2931
This commit is contained in:
parent
7fe2606cb3
commit
b980e563b9
@ -58,7 +58,7 @@ class InvokeAIGeneratorOutput:
|
|||||||
'''
|
'''
|
||||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||||
operation, including the image, its seed, the model name used to generate the image
|
operation, including the image, its seed, the model name used to generate the image
|
||||||
and the model hash, as well as all the generate() parameters that went into
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
generating the image (in .params, also available as attributes)
|
generating the image (in .params, also available as attributes)
|
||||||
'''
|
'''
|
||||||
image: Image
|
image: Image
|
||||||
@ -116,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
|
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
print(o.image, o.seed)
|
print(o.image, o.seed)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
generator_args = dataclasses.asdict(self.params)
|
generator_args = dataclasses.asdict(self.params)
|
||||||
generator_args.update(keyword_args)
|
generator_args.update(keyword_args)
|
||||||
@ -167,7 +167,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
if callback:
|
if callback:
|
||||||
callback(output)
|
callback(output)
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schedulers(self)->List[str]:
|
def schedulers(self)->List[str]:
|
||||||
'''
|
'''
|
||||||
@ -177,7 +177,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
return generator_class(model, self.params.precision)
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||||
@ -267,12 +267,12 @@ class Embiggen(Txt2Img):
|
|||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_class(cls):
|
def _generator_class(cls):
|
||||||
from .embiggen import Embiggen
|
from .embiggen import Embiggen
|
||||||
return Embiggen
|
return Embiggen
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
@ -347,7 +347,6 @@ class Generator:
|
|||||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||||
attention_maps_callback=attention_maps_callback,
|
attention_maps_callback=attention_maps_callback,
|
||||||
seed=seed,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
@ -375,7 +374,8 @@ class Generator:
|
|||||||
print("** An error occurred while getting initial noise **")
|
print("** An error occurred while getting initial noise **")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
image = make_image(x_T)
|
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||||
|
image = make_image(x_T, seed)
|
||||||
|
|
||||||
if self.safety_checker is not None:
|
if self.safety_checker is not None:
|
||||||
image = self.safety_checker.check(image)
|
image = self.safety_checker.check(image)
|
||||||
|
@ -37,7 +37,6 @@ class Img2Img(Generator):
|
|||||||
h_symmetry_time_pct=None,
|
h_symmetry_time_pct=None,
|
||||||
v_symmetry_time_pct=None,
|
v_symmetry_time_pct=None,
|
||||||
attention_maps_callback=None,
|
attention_maps_callback=None,
|
||||||
seed=None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -64,7 +63,7 @@ class Img2Img(Generator):
|
|||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
def make_image(x_T):
|
def make_image(x_T: torch.Tensor, seed: int):
|
||||||
# FIXME: use x_T for initial seeded noise
|
# FIXME: use x_T for initial seeded noise
|
||||||
# We're not at the moment because the pipeline automatically resizes init_image if
|
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||||
# necessary, which the x_T input might not match.
|
# necessary, which the x_T input might not match.
|
||||||
@ -77,7 +76,7 @@ class Img2Img(Generator):
|
|||||||
conditioning_data,
|
conditioning_data,
|
||||||
noise_func=self.get_noise_like,
|
noise_func=self.get_noise_like,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
seed=seed
|
seed=seed,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
pipeline_output.attention_map_saver is not None
|
pipeline_output.attention_map_saver is not None
|
||||||
@ -88,9 +87,7 @@ class Img2Img(Generator):
|
|||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
|
def get_noise_like(self, like: torch.Tensor):
|
||||||
if seed is not None:
|
|
||||||
set_seed(seed)
|
|
||||||
device = like.device
|
device = like.device
|
||||||
if device.type == "mps":
|
if device.type == "mps":
|
||||||
x = torch.randn_like(like, device="cpu").to(device)
|
x = torch.randn_like(like, device="cpu").to(device)
|
||||||
|
@ -311,7 +311,7 @@ class Inpaint(Img2Img):
|
|||||||
uc, c, cfg_scale
|
uc, c, cfg_scale
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
def make_image(x_T):
|
def make_image(x_T: torch.Tensor, seed: int):
|
||||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||||
init_image=init_image,
|
init_image=init_image,
|
||||||
mask=1 - mask, # expects white means "paint here."
|
mask=1 - mask, # expects white means "paint here."
|
||||||
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
|
|||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
noise_func=self.get_noise_like,
|
noise_func=self.get_noise_like,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
seed=seed
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -61,7 +61,7 @@ class Txt2Img(Generator):
|
|||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
def make_image(x_T) -> PIL.Image.Image:
|
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||||
pipeline_output = pipeline.image_from_embeddings(
|
pipeline_output = pipeline.image_from_embeddings(
|
||||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||||
noise=x_T,
|
noise=x_T,
|
||||||
|
@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
|
|||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
def make_image(x_T):
|
def make_image(x_T: torch.Tensor, _: int):
|
||||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||||
latents=torch.zeros_like(x_T),
|
latents=torch.zeros_like(x_T),
|
||||||
num_inference_steps=steps,
|
num_inference_steps=steps,
|
||||||
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
|||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
from accelerate.utils import set_seed
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
device=self._model_group.device_for(self.unet),
|
device=self._model_group.device_for(self.unet),
|
||||||
dtype=self.unet.dtype,
|
dtype=self.unet.dtype,
|
||||||
)
|
)
|
||||||
noise = noise_func(initial_latents, seed)
|
if seed is not None:
|
||||||
|
set_seed(seed)
|
||||||
|
noise = noise_func(initial_latents)
|
||||||
|
|
||||||
return self.img2img_from_latents_and_embeddings(
|
return self.img2img_from_latents_and_embeddings(
|
||||||
initial_latents,
|
initial_latents,
|
||||||
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
init_image_latents = self.non_noised_latents_from_image(
|
init_image_latents = self.non_noised_latents_from_image(
|
||||||
init_image, device=device, dtype=latents_dtype
|
init_image, device=device, dtype=latents_dtype
|
||||||
)
|
)
|
||||||
noise = noise_func(init_image_latents, seed)
|
if seed is not None:
|
||||||
|
set_seed(seed)
|
||||||
|
noise = noise_func(init_image_latents)
|
||||||
|
|
||||||
if mask.dim() == 3:
|
if mask.dim() == 3:
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user