InvokeAI/invokeai/backend/generator/img2img.py
Jonathan 2db180d909
Make img2img strength 1 behave the same as txt2img (#2895)
* Fix img2img and inpainting code so a strength of 1 behaves the same as txt2img.

* Make generated images identical to their txt2img counterparts when strength is 1.
2023-03-08 22:50:16 +01:00

105 lines
3.3 KiB
Python

"""
invokeai.backend.generator.img2img descends from .generator
"""
from typing import Optional
import torch
from accelerate.utils import set_seed
from diffusers import logging
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
from .base import Generator
class Img2Img(Generator):
    """Generator subclass that produces an image from conditioning plus an initial image."""

    def __init__(self, model, precision):
        """Forward ``model`` and ``precision`` to the base class and clear ``init_latent``."""
        super().__init__(model, precision)
        # Starts empty; per the original note it is populated "by get_noise()".
        self.init_latent = None

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        seed=None,
        **kwargs,
    ):
        """
        Build and return a ``make_image(x_T)`` closure that runs the img2img pipeline.

        Side effects at call time: stores ``perlin`` on ``self.perlin`` and assigns
        ``sampler`` to ``self.model.scheduler``.

        ``conditioning`` must unpack into three items; together with ``cfg_scale``
        and the post-processing options (``threshold``, ``warmup``,
        ``h_symmetry_time_pct``, ``v_symmetry_time_pct``) they are bundled into a
        ``ConditioningData``. ``ddim_eta`` is offered to the scheduler through
        ``add_scheduler_args_if_applicable``.

        ``prompt`` and any extra ``**kwargs`` are accepted but not used here.
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uncond, cond, extra_info = conditioning
        postprocessing = PostprocessingSettings(
            threshold=threshold,
            warmup=warmup,
            h_symmetry_time_pct=h_symmetry_time_pct,
            v_symmetry_time_pct=v_symmetry_time_pct,
        )
        unscheduled = ConditioningData(
            uncond,
            cond,
            cfg_scale,
            extra_info,
            postprocessing_settings=postprocessing,
        )
        conditioning_data = unscheduled.add_scheduler_args_if_applicable(
            pipeline.scheduler, eta=ddim_eta
        )

        def make_image(x_T):
            """Run the pipeline and return the first converted output image.

            ``x_T`` is currently ignored (see FIXME below).
            """
            # FIXME: use x_T for initial seeded noise.
            # It is not used yet because the pipeline automatically resizes
            # init_image if necessary, and the x_T input might not match.
            # Until then, the seed is handed to the pipeline (and get_noise_like
            # re-seeds when given one) so results stay reproducible.
            logging.set_verbosity_error()  # quench safety check warnings
            result = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )

            # Only report attention maps when both a saver and a listener exist.
            saver = result.attention_map_saver
            if saver is not None and attention_maps_callback is not None:
                attention_maps_callback(saver)

            converted = pipeline.numpy_to_pil(result.images)
            return converted[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
        """
        Return Gaussian noise shaped like ``like``, optionally blended with perlin noise.

        When ``seed`` is not None, ``set_seed(seed)`` is called before sampling.
        For an ``mps`` device the noise is sampled on the CPU and then moved to
        the device; otherwise it is sampled directly on ``like``'s device.
        When ``self.perlin`` is greater than 0.0 the result is
        ``(1 - perlin) * noise + perlin * self.get_perlin_noise(shape[3], shape[2])``.
        """
        if seed is not None:
            set_seed(seed)

        target = like.device
        if target.type == "mps":
            noise = torch.randn_like(like, device="cpu").to(target)
        else:
            noise = torch.randn_like(like, device=target)

        blend = self.perlin
        if not (blend > 0.0):
            return noise

        # get_perlin_noise is not defined in this class; presumably inherited
        # from Generator. It receives the last and second-to-last dims of `like`.
        dims = like.shape
        perlin_noise = self.get_perlin_noise(dims[3], dims[2])
        return (1 - blend) * noise + blend * perlin_noise