InvokeAI/invokeai/backend/generator/img2img.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

102 lines
3.2 KiB
Python
Raw Normal View History

2023-03-03 06:02:00 +00:00
"""
invokeai.backend.generator.img2img -- provides Img2Img, which descends from .base.Generator
"""
from typing import Optional
2023-02-28 05:37:13 +00:00
import torch
from accelerate.utils import set_seed
2023-02-28 05:37:13 +00:00
from diffusers import logging
2023-03-03 06:02:00 +00:00
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
2023-02-28 05:37:13 +00:00
from .base import Generator
2023-03-03 06:02:00 +00:00
2023-02-28 05:37:13 +00:00
class Img2Img(Generator):
    """Image-to-image generator built on the shared ``Generator`` base.

    ``get_make_image`` configures ``self.model`` (treated as a
    ``StableDiffusionGeneratorPipeline``) and returns a closure that runs
    ``img2img_from_embeddings`` on the supplied ``init_image``.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)
        # Placeholder for an initial latent; the original note says it is
        # populated "by get_noise()", which is not defined in this class.
        self.init_latent = None

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """Build and return a ``make_image(x_T, seed)`` closure.

        The closure derives an image from the conditioning and ``init_image``;
        what it produces depends on the ``seed`` it is called with.

        Side effects: stores ``perlin`` on ``self`` (read later by
        ``get_noise_like``) and installs ``sampler`` as the pipeline's
        scheduler.

        Args:
            prompt: accepted for interface compatibility; not read here.
            sampler: scheduler assigned to ``self.model.scheduler``.
            steps: forwarded to ``img2img_from_embeddings``.
            cfg_scale: forwarded to ``ConditioningData``.
            ddim_eta: passed as ``eta`` to
                ``add_scheduler_args_if_applicable``.
            conditioning: unpacked as ``(uc, c, extra_conditioning_info)``.
            init_image: forwarded to ``img2img_from_embeddings``.
            strength: forwarded to ``img2img_from_embeddings``.
            step_callback: forwarded as the pipeline's ``callback``.
            threshold, warmup, h_symmetry_time_pct, v_symmetry_time_pct:
                bundled into ``PostprocessingSettings``.
            perlin: Perlin-noise blend weight used by ``get_noise_like``.
            attention_maps_callback: when provided, called with the
                pipeline output's ``attention_map_saver`` if that is not None.
            **kwargs: accepted and ignored.

        Returns:
            ``make_image(x_T, seed)``, which returns element ``[0]`` of
            ``pipeline.numpy_to_pil(output.images)`` (presumably a PIL image).
        """
        self.perlin = perlin
        # noinspection PyTypeChecker
        pipe: StableDiffusionGeneratorPipeline = self.model
        pipe.scheduler = sampler

        uc, c, extra_info = conditioning
        post_settings = PostprocessingSettings(
            threshold=threshold,
            warmup=warmup,
            h_symmetry_time_pct=h_symmetry_time_pct,
            v_symmetry_time_pct=v_symmetry_time_pct,
        )
        # The scheduler is read back from the pipeline (not taken from
        # `sampler` directly), matching whatever the assignment above stored.
        cond_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_info,
            postprocessing_settings=post_settings,
        ).add_scheduler_args_if_applicable(pipe.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, seed: int):
            # FIXME: use x_T for initial seeded noise.
            # x_T is ignored for now: the pipeline may resize init_image, so
            # x_T might not match it. Passing `seed` through to the pipeline
            # keeps results reproducible in the meantime.
            logging.set_verbosity_error()  # quench safety check warnings

            result = pipe.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                cond_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )

            saver = result.attention_map_saver
            if saver is not None and attention_maps_callback is not None:
                attention_maps_callback(saver)

            return pipe.numpy_to_pil(result.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor):
        """Return standard-normal noise shaped like ``like``.

        On ``mps`` devices the noise is sampled on the CPU and then moved to
        ``like.device``; on any other device it is sampled there directly.
        When ``self.perlin > 0.0`` the result is
        ``(1 - perlin) * noise + perlin * self.get_perlin_noise(like.shape[3], like.shape[2])``.
        """
        target = like.device
        noise = (
            torch.randn_like(like, device="cpu").to(target)
            if target.type == "mps"
            else torch.randn_like(like, device=target)
        )

        weight = self.perlin
        if weight > 0.0:
            dims = like.shape
            perlin_noise = self.get_perlin_noise(dims[3], dims[2])
            noise = (1 - weight) * noise + weight * perlin_noise
        return noise