Merge branch 'main' into feat_longer_prompts

Damian Stewart 2023-03-09 00:13:01 +01:00 committed by GitHub
commit 9ee648e0c3
4 changed files with 26 additions and 12 deletions

@@ -99,6 +99,7 @@ class Generator:
                 h_symmetry_time_pct=h_symmetry_time_pct,
                 v_symmetry_time_pct=v_symmetry_time_pct,
                 attention_maps_callback=attention_maps_callback,
+                seed=seed,
                 **kwargs,
             )
         results = []
@@ -290,8 +291,6 @@ class Generator:
             random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
             seed = random.randrange(0, np.iinfo(np.uint32).max)
             return (seed, initial_noise)
-        else:
-            return (seed, None)

     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self, width, height):
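
The hunks above thread the new seed argument into the generation call and drop the early `else: return (seed, None)` exit, so the helper always resolves to a concrete seed. As a rough standalone sketch of the retained seed-selection logic (the function name here is hypothetical, not the repo's):

    import random
    from typing import Optional

    import numpy as np

    def pick_seed(requested: Optional[int]) -> int:
        # A caller-supplied seed wins; otherwise re-seed Python's RNG from
        # OS entropy and draw a fresh seed from the full uint32 range.
        if requested is not None:
            return requested
        random.seed()
        return random.randrange(0, np.iinfo(np.uint32).max)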

@@ -1,8 +1,10 @@
 """
 invokeai.backend.generator.img2img descends from .generator
 """
+from typing import Optional

 import torch
+from accelerate.utils import set_seed
 from diffusers import logging

 from ..stable_diffusion import (
@@ -35,6 +37,7 @@ class Img2Img(Generator):
         h_symmetry_time_pct=None,
         v_symmetry_time_pct=None,
         attention_maps_callback=None,
+        seed=None,
         **kwargs,
     ):
         """
@@ -65,6 +68,7 @@ class Img2Img(Generator):
             # FIXME: use x_T for initial seeded noise
             # We're not at the moment because the pipeline automatically resizes init_image if
             # necessary, which the x_T input might not match.
+            # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
             logging.set_verbosity_error()  # quench safety check warnings
             pipeline_output = pipeline.img2img_from_embeddings(
                 init_image,
@@ -73,6 +77,7 @@ class Img2Img(Generator):
                 conditioning_data,
                 noise_func=self.get_noise_like,
                 callback=step_callback,
+                seed=seed
             )
             if (
                 pipeline_output.attention_map_saver is not None
@@ -83,7 +88,9 @@ class Img2Img(Generator):
         return make_image

-    def get_noise_like(self, like: torch.Tensor):
+    def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
+        if seed is not None:
+            set_seed(seed)
         device = like.device
         if device.type == "mps":
             x = torch.randn_like(like, device="cpu").to(device)

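The substantive change here is `get_noise_like`: it now takes an optional seed and calls `accelerate.utils.set_seed` (which seeds Python's, NumPy's, and torch's RNGs in one call) before sampling, so img2img noise becomes reproducible. The diff truncates the function body; a minimal reconstruction, assuming the non-MPS branch follows the usual `randn_like` pattern:

    from typing import Optional

    import torch
    from accelerate.utils import set_seed

    def get_noise_like(like: torch.Tensor, seed: Optional[int]) -> torch.Tensor:
        if seed is not None:
            set_seed(seed)  # makes the randn_like call below deterministic
        device = like.device
        if device.type == "mps":
            # MPS has historically had spotty RNG support, so sample on the
            # CPU and move the result to the target device afterwards
            x = torch.randn_like(like, device="cpu").to(device)
        else:
            x = torch.randn_like(like, device=device)
        return x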

@@ -223,6 +223,7 @@ class Inpaint(Img2Img):
         inpaint_height=None,
         inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
         attention_maps_callback=None,
+        seed=None,
         **kwargs,
     ):
         """
@@ -319,6 +320,7 @@ class Inpaint(Img2Img):
             conditioning_data=conditioning_data,
             noise_func=self.get_noise_like,
             callback=step_callback,
+            seed=seed
         )
         if (

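`Inpaint` only needs to forward the new argument, since it inherits `get_noise_like` from `Img2Img`. Using the sketch above, the reproducibility guarantee can be checked directly:

    import torch

    latents = torch.zeros(1, 4, 64, 64)
    a = get_noise_like(latents, seed=42)
    b = get_noise_like(latents, seed=42)
    assert torch.equal(a, b)  # same seed -> identical noise
    c = get_noise_like(latents, seed=None)
    # with no seed, the global RNG state is left alone and the draw stays random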

@@ -690,6 +690,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
         noise_func=None,
+        seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
@@ -703,7 +704,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             device=self._model_group.device_for(self.unet),
             dtype=self.unet.dtype,
         )
-        noise = noise_func(initial_latents)
+        noise = noise_func(initial_latents, seed)

         return self.img2img_from_latents_and_embeddings(
             initial_latents,
@@ -731,9 +732,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             device=self._model_group.device_for(self.unet),
         )
         result_latents, result_attention_maps = self.latents_from_embeddings(
-            initial_latents,
-            num_inference_steps,
-            conditioning_data,
+            latents=initial_latents if strength < 1.0 else torch.zeros_like(
+                initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
+            ),
+            num_inference_steps=num_inference_steps,
+            conditioning_data=conditioning_data,
             timesteps=timesteps,
             noise=noise,
             run_id=run_id,
@@ -779,6 +782,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         callback: Callable[[PipelineIntermediateState], None] = None,
         run_id=None,
         noise_func=None,
+        seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
         device = self._model_group.device_for(self.unet)
         latents_dtype = self.unet.dtype
@@ -802,7 +806,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         init_image_latents = self.non_noised_latents_from_image(
             init_image, device=device, dtype=latents_dtype
         )
-        noise = noise_func(init_image_latents)
+        noise = noise_func(init_image_latents, seed)

         if mask.dim() == 3:
             mask = mask.unsqueeze(0)
@@ -831,9 +835,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         try:
             result_latents, result_attention_maps = self.latents_from_embeddings(
-                init_image_latents,
-                num_inference_steps,
-                conditioning_data,
+                latents=init_image_latents if strength < 1.0 else torch.zeros_like(
+                    init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
+                ),
+                num_inference_steps=num_inference_steps,
+                conditioning_data=conditioning_data,
                 noise=noise,
                 timesteps=timesteps,
                 additional_guidance=guidance,
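
Besides forwarding `seed` into `noise_func`, both pipeline call sites now zero out the starting latents when `strength >= 1.0`: at full strength the init image would be noised away entirely, so starting from zeros lets the (now seeded) noise alone determine the output instead of leaking image latents into it. The inline conditional distils to something like this (hypothetical helper, not part of the commit):

    import torch

    def starting_latents(image_latents: torch.Tensor, strength: float) -> torch.Tensor:
        # strength < 1.0: part of the source image survives, so start from
        # its latents; strength >= 1.0: start from zeros and let the seeded
        # noise fully determine the result.
        if strength < 1.0:
            return image_latents
        return torch.zeros_like(
            image_latents, device=image_latents.device, dtype=image_latents.dtype
        )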