InvokeAI/invokeai/backend/generator/txt2img.py

126 lines
4.6 KiB
Python
Raw Normal View History

"""
invokeai.backend.generator.txt2img defines Txt2Img, a text-to-image generator
that inherits from invokeai.backend.generator.base.Generator.
"""
2023-02-28 05:37:13 +00:00
import PIL.Image
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
2023-06-07 16:42:52 +00:00
from diffusers.pipelines.controlnet import MultiControlNetModel
2023-03-03 06:02:00 +00:00
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
2023-02-28 05:37:13 +00:00
from .base import Generator
2023-03-03 06:02:00 +00:00
2023-02-28 05:37:13 +00:00
class Txt2Img(Generator):
    """Text-to-image generator with optional ControlNet conditioning.

    The generator builds a ``make_image`` closure (see ``get_make_image``)
    that runs ``self.model`` -- used here as a
    ``StableDiffusionGeneratorPipeline`` -- on a caller-supplied noise tensor.
    """

    def __init__(
        self,
        model,
        precision,
        control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
        **kwargs,
    ):
        """Store the optional ControlNet model(s) and initialise the base class.

        :param model: forwarded unchanged to ``Generator.__init__``.
        :param precision: forwarded unchanged to ``Generator.__init__``.
        :param control_model: a single ``ControlNetModel``, a list of them
            (wrapped into one ``MultiControlNetModel``), or ``None``.
        :param kwargs: forwarded unchanged to ``Generator.__init__``.
        """
        # Assigned before super().__init__() to keep the original ordering.
        self.control_model = control_model
        if isinstance(self.control_model, list):
            self.control_model = MultiControlNetModel(self.control_model)
        super().__init__(model, precision, **kwargs)

    def _prepare_control_image(
        self,
        pipeline,
        control_image,
        width,
        height,
        do_classifier_free_guidance,
    ):
        """Run ``pipeline.prepare_control_image`` on the control image(s).

        With a single ``ControlNetModel`` the image is prepared directly; with
        a ``MultiControlNetModel`` ``control_image`` is iterated and a list of
        prepared images is returned.  For any other ``self.control_model``
        (including ``None``) the input is returned unchanged.
        """

        def prepare(image):
            # FIXME: still need to test with different widths, heights, devices, dtypes
            #        and add in batch_size, num_images_per_prompt?
            return pipeline.prepare_control_image(
                image=image,
                do_classifier_free_guidance=do_classifier_free_guidance,
                width=width,
                height=height,
                # batch_size=batch_size * num_images_per_prompt,
                # num_images_per_prompt=num_images_per_prompt,
                device=self.control_model.device,
                dtype=self.control_model.dtype,
            )

        if isinstance(self.control_model, ControlNetModel):
            return prepare(control_image)
        if isinstance(self.control_model, MultiControlNetModel):
            # NOTE(review): assumes control_image is an iterable of images
            # here, presumably one per wrapped ControlNet -- confirm with caller.
            return [prepare(image) for image in control_image]
        return control_image

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Returns a function returning an image derived from the conditioning.
        Return value depends on the noise tensor passed to that function.

        ``sampler`` is installed as the pipeline's scheduler and
        ``self.control_model`` as its control model.  ``conditioning`` must
        unpack into three items, which are handed to ``ConditioningData``
        together with ``cfg_scale`` and the post-processing settings
        (``threshold``, ``warmup``, ``h_symmetry_time_pct``,
        ``v_symmetry_time_pct``).  ``ddim_eta`` is passed as ``eta`` to
        ``add_scheduler_args_if_applicable``.

        ``kwargs`` may contain ``control_image``; when it is not None it is
        prepared for the configured control model (using ``width`` and
        ``height``) and written back.  All of ``kwargs`` is then forwarded to
        ``pipeline.image_from_embeddings``.

        ``step_callback`` is passed to the pipeline as ``callback``.
        ``attention_maps_callback``, if given, is called with the pipeline
        output's ``attention_map_saver`` whenever that is not None.

        ``prompt`` is accepted for interface compatibility and is not used here.
        """
        self.perlin = perlin
        control_image = kwargs.get("control_image", None)
        do_classifier_free_guidance = cfg_scale > 1.0

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.control_model = self.control_model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        if control_image is not None:
            kwargs["control_image"] = self._prepare_control_image(
                pipeline,
                control_image,
                width,
                height,
                do_classifier_free_guidance,
            )

        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            # The @torch.no_grad() on get_make_image only covers building this
            # closure; the closure runs later, so gradient tracking has to be
            # disabled again here for the actual diffusion call.
            with torch.no_grad():
                pipeline_output = pipeline.image_from_embeddings(
                    latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
                    noise=x_T,
                    num_inference_steps=steps,
                    conditioning_data=conditioning_data,
                    callback=step_callback,
                    **kwargs,
                )

            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image