InvokeAI/invokeai/backend/generator/txt2img.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

82 lines
2.4 KiB
Python
Raw Normal View History

2023-03-03 06:02:00 +00:00
"""
invokeai.backend.generator.txt2img defines Txt2Img, a text-to-image generator
that subclasses invokeai.backend.generator.base.Generator.
"""
2023-02-28 05:37:13 +00:00
import PIL.Image
import torch
2023-03-03 06:02:00 +00:00
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
2023-02-28 05:37:13 +00:00
from .base import Generator
2023-03-03 06:02:00 +00:00
2023-02-28 05:37:13 +00:00
class Txt2Img(Generator):
    """Generator that produces an image from text conditioning alone (no init image)."""

    def __init__(self, model, precision):
        super().__init__(model, precision)

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        width,
        height,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """
        Return a function that renders a single image from the given conditioning.

        The returned callable has the signature ``make_image(x_T, _)``. It runs the
        pipeline starting from all-zero latents (same shape as ``x_T``, in
        ``self.torch_dtype()``), uses ``x_T`` as the noise tensor, and returns the
        first image of the output as a ``PIL.Image.Image``. Its second (int)
        argument is ignored, so the result is determined by ``x_T``. Presumably
        the caller derives ``x_T`` from the seed; confirm in the base Generator.

        :param prompt: unused by this method; kept for interface compatibility.
        :param sampler: scheduler object; installed as ``pipeline.scheduler``.
        :param steps: number of inference steps passed to the pipeline.
        :param cfg_scale: guidance scale passed to ``ConditioningData``.
        :param ddim_eta: passed as ``eta`` to ``add_scheduler_args_if_applicable``.
        :param conditioning: 3-tuple unpacked as ``(uc, c, extra_conditioning_info)``.
            Presumably these are the unconditioned and conditioned embeddings;
            verify against the caller.
        :param width: unused by this method.
        :param height: unused by this method.
        :param step_callback: forwarded to the pipeline as ``callback``.
        :param threshold: forwarded to ``PostprocessingSettings``.
        :param warmup: forwarded to ``PostprocessingSettings``.
        :param perlin: stored on ``self.perlin``; not read in this method.
            Presumably it is consumed by the base class's noise generation.
        :param h_symmetry_time_pct: forwarded to ``PostprocessingSettings``.
        :param v_symmetry_time_pct: forwarded to ``PostprocessingSettings``.
        :param attention_maps_callback: if given, called with the pipeline output's
            ``attention_map_saver`` whenever that attribute is not None.
        :param kwargs: accepted and ignored.
        :return: the ``make_image`` closure described above.
        """
        self.perlin = perlin
        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        # The decorator on get_make_image only covers building this closure; the
        # actual inference happens later when make_image is called, so disable
        # gradient tracking here too.
        @torch.no_grad()
        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
                noise=x_T,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image