mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
3ffff023b2
It was breaking because the sampler was not being reset, so each one needs a key. Will simplify this later.
1252 lines
48 KiB
Python
1252 lines
48 KiB
Python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
|
# Derived from source code carrying the following copyrights
|
|
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
|
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
|
|
|
import gc
|
|
import importlib
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from typing import List
|
|
|
|
import cv2
|
|
import diffusers
|
|
import numpy as np
|
|
import skimage
|
|
import torch
|
|
import transformers
|
|
from PIL import Image, ImageOps
|
|
from accelerate.utils import set_seed
|
|
from diffusers.pipeline_utils import DiffusionPipeline
|
|
from diffusers.utils.import_utils import is_xformers_available
|
|
from omegaconf import OmegaConf
|
|
from pathlib import Path
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
from .args import metadata_from_png
|
|
from .generator import infill_methods
|
|
from .globals import Globals, global_cache_dir
|
|
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
|
from .model_management import ModelManager
|
|
from .safety_checker import SafetyChecker
|
|
from .prompting import get_uc_and_c_and_ec
|
|
from .prompting.conditioning import log_tokenization
|
|
from .stable_diffusion import HuggingFaceConceptsLibrary
|
|
from .util import choose_precision, choose_torch_device
|
|
|
|
def fix_func(orig):
    """Route a torch random-number factory through the CPU on Apple MPS.

    When the MPS backend is available, the returned wrapper calls ``orig``
    with ``device="cpu"`` and then moves the result to the device the caller
    asked for (``"mps"`` when no ``device`` keyword was given). Presumably this
    keeps random draws on the CPU RNG on Macs -- TODO confirm the original
    motivation. On every other platform ``orig`` is returned unchanged.

    :param orig: a torch function that accepts a ``device`` keyword argument.
    :returns: ``orig`` itself, or a wrapper that keeps ``orig``'s name and
        docstring (via functools.wraps).
    """
    import functools

    if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
        return orig

    @functools.wraps(orig)
    def new_func(*args, **kw):
        # Remember where the caller wants the result, then generate on CPU.
        device = kw.get("device", "mps")
        kw["device"] = "cpu"
        return orig(*args, **kw).to(device)

    return new_func
|
|
|
|
|
|
def _fix_input_func(orig):
    """Route an input-tensor torch sampler through the CPU on Apple MPS.

    ``torch.bernoulli`` and ``torch.multinomial`` take an input tensor and have
    no ``device`` keyword, so they cannot be wrapped with fix_func (it injects
    ``device=...``, which would raise TypeError). On MPS, this wrapper instead
    moves the input tensor to the CPU, samples there, and moves the result back
    to the input's original device. Elsewhere ``orig`` is returned unchanged.

    NOTE(review): an ``out=`` tensor living on MPS is not handled here.
    """
    import functools

    if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
        return orig

    @functools.wraps(orig)
    def new_func(*args, **kw):
        # The input tensor may be passed positionally or as ``input=``.
        if args:
            tensor, args = args[0], args[1:]
        else:
            tensor = kw.pop("input")
        return orig(tensor.to("cpu"), *args, **kw).to(tensor.device)

    return new_func


# Monkey-patch torch's random factories so that, on Apple MPS, random numbers
# are drawn on the CPU and then moved to the requested device.
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
# These two accept no ``device`` keyword; route them via their input tensor.
torch.bernoulli = _fix_input_func(torch.bernoulli)
torch.multinomial = _fix_input_func(torch.multinomial)
|
|
|
|
# this is the fallback model used when models.yaml does not define a default
FALLBACK_MODEL_NAME = "stable-diffusion-1.5"

"""Simplified text to image API for stable diffusion/latent diffusion

Example Usage:

from ldm.generate import Generate

# Create an object with default values
gr = Generate('stable-diffusion-1.4')

# do the slow model initialization
gr.load_model()

# Do the fast inference & image generation. Any options passed here
# override the default values assigned during class initialization
# Will call load_model() if the model was not previously loaded and so
# may be slow at first.
# The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
results = gr.prompt2png(prompt = "an astronaut riding a horse",
                        outdir = "./outputs/samples",
                        iterations = 3)

for row in results:
    print(f'filename={row[0]}')
    print(f'seed    ={row[1]}')

# Same thing, but using an initial image.
results = gr.prompt2png(prompt   = "an astronaut riding a horse",
                        outdir   = "./outputs/",
                        iterations = 3,
                        init_img = "./sketches/horse+rider.png")

for row in results:
    print(f'filename={row[0]}')
    print(f'seed    ={row[1]}')

# Same thing, but we return a series of Image objects, which lets you manipulate them,
# combine them, and save them under arbitrary names

results = gr.prompt2image(prompt = "an astronaut riding a horse",
                          outdir = "./outputs/")
for row in results:
    im   = row[0]
    seed = row[1]
    im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
    im.thumbnail((100, 100))   # resizes in place and returns None
    im.save('./outputs/samples/astronaut_thumb.jpg')

Note that the old txt2img() and img2img() calls are deprecated but will
still work.

The full list of arguments to Generate() are:
gr = Generate(
          # these values are set once and shouldn't be changed
          conf:str   = path to configuration file ('configs/models.yaml')
          model:str  = symbolic name of the model in the configuration file
          precision:str = floating-point precision to use ('auto', 'float32', ...)
          safety_checker:bool = activate safety checker [False]

          # this value is sticky and maintained between generation calls
          sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_dpmpp_2', 'k_dpmpp_2_a', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms

          # these are deprecated (currently ignored) - use conf and model instead
          weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
          config     = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
          )

"""
|
|
|
|
|
|
class Generate:
|
|
"""Generate class
|
|
Stores default values for multiple configuration items
|
|
"""
|
|
|
|
def __init__(
    self,
    model=None,
    conf="configs/models.yaml",
    embedding_path=None,
    sampler_name="k_lms",
    ddim_eta=0.0,  # deterministic
    full_precision=False,
    precision="auto",
    outdir="outputs/img-samples",
    gfpgan=None,
    codeformer=None,
    esrgan=None,
    free_gpu_mem: bool = False,
    safety_checker: bool = False,
    max_loaded_models: int = 2,
    # these are deprecated; if present they override values in the conf file
    weights=None,
    config=None,
):
    """
    Set up a generation session: device, precision, model manager and options.

    :param model: name of a model defined in ``conf``. When omitted or unknown,
        the config's default model (or FALLBACK_MODEL_NAME) is used instead.
    :param conf: path to the models.yaml file, loaded with OmegaConf.
    :param embedding_path: directory of textual-inversion embeddings passed to
        the ModelManager, or None.
    :param sampler_name: initial sampler name (sticky between generations).
    :param ddim_eta: default DDIM eta; 0.0 means the same seed reproduces the
        same image.
    :param full_precision: deprecated alias for ``precision="float32"``.
        Raises ValueError if combined with a non-"auto" ``precision``.
    :param precision: "auto" (chosen by choose_precision for the device) or an
        explicit precision string such as "float32".
    :param outdir: default output directory.
    :param gfpgan, codeformer, esrgan: optional face-restoration/upscaling
        helper objects used during post-processing, or None.
    :param free_gpu_mem: forwarded to the ModelManager as ``sequential_offload``.
    :param safety_checker: when True, create a SafetyChecker on the device.
    :param max_loaded_models: how many models the ModelManager may keep loaded.
    :param weights, config: deprecated. NOTE(review): they are not read
        anywhere in this method, despite the comment in the signature.
    """
    mconfig = OmegaConf.load(conf)
    self.height = None
    self.width = None
    self.model_manager = None
    self.iterations = 1
    self.steps = 50
    self.cfg_scale = 7.5
    self.sampler_name = sampler_name
    self.ddim_eta = ddim_eta  # same seed always produces same image
    self.precision = precision
    self.strength = 0.75
    self.seamless = False
    self.seamless_axes = {"x", "y"}
    self.hires_fix = False
    self.embedding_path = embedding_path
    self.model = None  # empty for now
    self.model_hash = None
    self.sampler = None
    self.device = None
    self.max_memory_allocated = 0
    self.memory_allocated = 0
    self.session_peakmem = 0
    self.base_generator = None
    self.seed = None
    self.outdir = outdir
    self.gfpgan = gfpgan
    self.codeformer = codeformer
    self.esrgan = esrgan
    self.free_gpu_mem = free_gpu_mem
    # Store the plain int (a stray trailing comma used to make this a 1-tuple).
    self.max_loaded_models = max_loaded_models
    self.size_matters = True  # used to warn once about large image sizes and VRAM
    self.txt2mask = None
    self.safety_checker = None
    self.karras_max = None
    self.infill_method = None

    # Note that in previous versions, there was an option to pass the
    # device to Generate(). However the device was then ignored, so
    # it wasn't actually doing anything. This logic could be reinstated.
    self.device = torch.device(choose_torch_device())
    logger.info(f"Using device_type {self.device.type}")
    if full_precision:
        if self.precision != "auto":
            raise ValueError("Remove --full_precision / -F if using --precision")
        logger.warning("Please remove deprecated --full_precision / -F")
        logger.warning("If auto config does not work you can use --precision=float32")
        self.precision = "float32"
    if self.precision == "auto":
        self.precision = choose_precision(self.device)
    Globals.full_precision = self.precision == "float32"

    if is_xformers_available():
        if torch.cuda.is_available() and not Globals.disable_xformers:
            logger.info("xformers memory-efficient attention is available and enabled")
        else:
            logger.info(
                "xformers memory-efficient attention is available but disabled"
            )
    else:
        logger.info("xformers not installed")

    # model caching system for fast switching
    self.model_manager = ModelManager(
        mconfig,
        self.device,
        self.precision,
        max_loaded_models=max_loaded_models,
        sequential_offload=self.free_gpu_mem,
        # Path(None) raises TypeError, so only wrap a real value.
        # NOTE(review): assumes ModelManager accepts embedding_path=None.
        embedding_path=(
            Path(self.embedding_path) if self.embedding_path is not None else None
        ),
    )
    # don't accept invalid models
    fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
    model = model or fallback
    if not self.model_manager.valid_model(model):
        logger.warning(
            f'"{model}" is not a known model name; falling back to {fallback}.'
        )
        model = None
    self.model_name = model or fallback

    # for VRAM usage statistics; _has_cuda must be *called* -- the bound
    # method object itself is always truthy.
    self.session_peakmem = (
        torch.cuda.max_memory_allocated(self.device) if self._has_cuda() else None
    )
    transformers.logging.set_verbosity_error()

    # gets rid of annoying messages about random seed
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

    # load safety checker if requested
    if safety_checker:
        logger.info("Initializing NSFW checker")
        self.safety_checker = SafetyChecker(self.device)
    else:
        logger.info("NSFW checker is disabled")
|
|
|
|
def prompt2png(self, prompt, outdir, **kwargs):
    """
    Generate images for ``prompt`` and write each one as a PNG into ``outdir``.

    Extra keyword arguments are forwarded unchanged to prompt2image(). All
    files from one call share a unique prefix and are named
    ``<prefix>.<seed>.png``. The PngWriter is given the prompt followed by
    ``-S<seed>`` as the ``dream_prompt`` to store with the image.

    Returns a list of ``[path, seed]`` pairs, one per generated image.
    """
    generated = self.prompt2image(prompt, **kwargs)
    writer = PngWriter(outdir)
    stem = writer.unique_prefix()

    saved = []
    for img, img_seed in generated:
        file_name = f"{stem}.{img_seed}.png"
        saved_path = writer.save_image_and_prompt_to_png(
            img,
            dream_prompt=f"{prompt} -S{img_seed}",
            name=file_name,
        )
        saved.append([saved_path, img_seed])
    return saved
|
|
|
|
def txt2img(self, prompt, **kwargs):
    """Deprecated alias for prompt2png(); ``outdir`` defaults to self.outdir."""
    target_dir = kwargs.pop("outdir", self.outdir)
    return self.prompt2png(prompt, target_dir, **kwargs)
|
|
|
|
def img2img(self, prompt, **kwargs):
    """Deprecated alias for prompt2png() that insists on an ``init_img`` kwarg.

    ``outdir`` defaults to self.outdir; everything else is forwarded as-is.
    Raises AssertionError when ``init_img`` is missing.
    """
    target_dir = kwargs.pop("outdir", self.outdir)
    has_init = "init_img" in kwargs
    assert has_init, "call to img2img() must include the init_img argument"
    return self.prompt2png(prompt, target_dir, **kwargs)
|
|
|
|
def prompt2image(
    self,
    # these are common
    prompt,
    iterations=None,
    steps=None,
    seed=None,
    cfg_scale=None,
    ddim_eta=None,
    skip_normalize=False,
    image_callback=None,
    step_callback=None,
    width=None,
    height=None,
    sampler_name=None,
    seamless=False,
    seamless_axes=None,
    log_tokenization=False,
    with_variations=None,
    variation_amount=0.0,
    threshold=0.0,
    perlin=0.0,
    h_symmetry_time_pct=None,
    v_symmetry_time_pct=None,
    karras_max=None,
    outdir=None,
    # these are specific to img2img and inpaint
    init_img=None,
    init_mask=None,
    text_mask=None,
    invert_mask=False,
    fit=False,
    strength=None,
    init_color=None,
    # these are specific to embiggen (which also relies on img2img args)
    embiggen=None,
    embiggen_tiles=None,
    embiggen_strength=None,
    # these are specific to GFPGAN/ESRGAN
    gfpgan_strength=0,
    facetool=None,
    facetool_strength=0,
    codeformer_fidelity=None,
    save_original=False,
    upscale=None,
    upscale_denoise_str=0.75,
    # this is specific to inpainting and causes more extreme inpainting
    inpaint_replace=0.0,
    # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result)
    inpaint_width=None,
    inpaint_height=None,
    # This will help match inpainted areas to the original image more smoothly
    mask_blur_radius: int = 8,
    # Set this True to handle KeyboardInterrupt internally
    catch_interrupts=False,
    hires_fix=False,
    use_mps_noise=False,
    # Seam settings for outpainting
    seam_size: int = 0,
    seam_blur: int = 0,
    seam_strength: float = 0.7,
    seam_steps: int = 10,
    tile_size: int = 32,
    infill_method=None,
    force_outpaint: bool = False,
    enable_image_debugging=False,
    **args,
):  # eat up additional cruft
    """
    prompt2image() is the common entry point for txt2img() and img2img().
    It takes the following arguments:
       prompt                          // prompt string (no default)
       iterations                      // iterations (1); image count=iterations
       steps                           // refinement steps per iteration
       seed                            // seed for random number generator
       width                           // width of image, in multiples of 64 (512)
       height                          // height of image, in multiples of 64 (512)
       cfg_scale                       // how strongly the prompt influences the image (7.5) (must be >1)
       seamless                        // whether the generated image should tile
       seamless_axes                   // axes to tile along (defaults to self.seamless_axes, {"x", "y"})
       hires_fix                       // whether the Hires Fix should be applied during generation
       init_img                        // path to an initial image
       init_mask                       // path to a mask for the initial image
       text_mask                       // a text string that will be used to guide clipseg generation of the init_mask
       invert_mask                     // boolean, if true invert the mask
       strength                        // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
       facetool_strength               // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
       ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)
       step_callback                   // a function or method that will be called each step
       image_callback                  // a function or method that will be called each time an image is generated
       with_variations                 // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
       variation_amount                // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
       threshold                       // optional value >=0 to add thresholding to latent values for k-diffusion samplers (0 disables)
       perlin                          // optional 0-1 value to add a percentage of perlin noise to the initial noise
       h_symmetry_time_pct             // optional 0-1 value that indicates the time at which horizontal symmetry is applied
       v_symmetry_time_pct             // optional 0-1 value that indicates the time at which vertical symmetry is applied
       embiggen                        // scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels
       embiggen_tiles                  // list of tiles by number in order to process and replace onto the image e.g. `0 2 4`
       embiggen_strength               // strength for embiggen. 0.0 preserves image exactly, 1.0 replaces it completely
       catch_interrupts                // if True, a KeyboardInterrupt is logged and the (possibly partial) results returned

    To use the step callback, define a function that receives two arguments:
    - Image GPU data
    - The step number

    To use the image callback, define a function or method that receives two arguments, an Image object
    and the seed. You can then do whatever you like with the image, including converting it to
    different formats and manipulating it. For example:

        def process_image(image,seed):
            image.save(f'images/{seed}.png')

    The code used to save images to a directory can be found in PngWriter
    (imported from .image_util). It contains code to create the requested output
    directory, select a unique informative name for each image, and write the
    prompt into the PNG metadata.

    Returns the list produced by the selected generator (prompt2png() unpacks
    its entries as (image, seed)); the list is empty if generation failed with
    a RuntimeError.
    """
    self.clear_cuda_stats()

    # Fill in any unset option from the instance-level defaults.
    # TODO: convert this into a getattr() loop
    steps = steps or self.steps
    width = width or self.width
    height = height or self.height
    seamless = seamless or self.seamless
    seamless_axes = seamless_axes or self.seamless_axes
    hires_fix = hires_fix or self.hires_fix
    cfg_scale = cfg_scale or self.cfg_scale
    ddim_eta = ddim_eta or self.ddim_eta
    iterations = iterations or self.iterations
    strength = strength or self.strength
    self.seed = seed
    self.log_tokenization = log_tokenization
    self.step_callback = step_callback
    self.karras_max = karras_max
    # The infill method to use (a plain value, not a 1-tuple).
    self.infill_method = infill_method or infill_methods()[0]
    with_variations = [] if with_variations is None else with_variations

    # will instantiate the model or return it from cache
    model = self.set_model(self.model_name)

    # self.width and self.height are set by set_model()
    # to the width and height of the image training set
    width = width or self.width
    height = height or self.height

    if isinstance(model, DiffusionPipeline):
        configure_model_padding(model.unet, seamless, seamless_axes)
        configure_model_padding(model.vae, seamless, seamless_axes)
    else:
        configure_model_padding(model, seamless, seamless_axes)

    assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0"
    assert threshold >= 0.0, "--threshold must be >=0.0"
    assert (
        0.0 < strength <= 1.0
    ), "img2img and inpaint strength can only work with 0.0 < strength <= 1.0"
    assert (
        0.0 <= variation_amount <= 1.0
    ), "-v --variation_amount must be in [0.0, 1.0]"
    assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]"
    assert (embiggen is None and embiggen_tiles is None) or (
        (embiggen is not None or embiggen_tiles is not None) and init_img is not None
    ), "Embiggen requires an init/input image to be specified"

    # NOTE(review): variation_amount is asserted to be <= 1.0 above, so the
    # second condition can never hold; confirm whether ``> 0.0`` was intended.
    if len(with_variations) > 0 or variation_amount > 1.0:
        assert seed is not None, "seed must be specified when using with_variations"
        if variation_amount == 0.0:
            assert (
                iterations == 1
            ), "when using --with_variations, multiple iterations are only possible when using --variation_amount"
        assert all(
            0 <= weight <= 1 for _, weight in with_variations
        ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}"

    width, height, _ = self._resolution_check(width, height, log=True)
    assert (
        inpaint_replace >= 0.0 and inpaint_replace <= 1.0
    ), "inpaint_replace must be between 0.0 and 1.0"

    if sampler_name and (sampler_name != self.sampler_name):
        self.sampler_name = sampler_name
        self._set_scheduler()

    # apply the concepts library to the prompt
    prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
        prompt,
        lambda concepts: self.load_huggingface_concepts(concepts),
        self.model.textual_inversion_manager.get_all_trigger_strings(),
    )

    tic = time.time()
    if self._has_cuda():
        torch.cuda.reset_peak_memory_stats()

    results = list()

    try:
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
            prompt,
            model=self.model,
            skip_normalize_legacy_blend=skip_normalize,
            log_tokens=self.log_tokenization,
        )

        init_image, mask_image = self._make_images(
            init_img,
            init_mask,
            width,
            height,
            fit=fit,
            text_mask=text_mask,
            invert_mask=invert_mask,
            force_outpaint=force_outpaint,
        )

        # TODO: Hacky selection of operation to perform. Needs to be refactored.
        generator = self.select_generator(
            init_image, mask_image, embiggen, hires_fix, force_outpaint
        )

        generator.set_variation(self.seed, variation_amount, with_variations)
        generator.use_mps_noise = use_mps_noise

        results = generator.generate(
            prompt,
            iterations=iterations,
            seed=self.seed,
            sampler=self.sampler,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=(uc, c, extra_conditioning_info),
            ddim_eta=ddim_eta,
            image_callback=image_callback,  # called after the final image is generated
            step_callback=step_callback,  # called after each intermediate image is generated
            width=width,
            height=height,
            init_img=init_img,  # embiggen needs to manipulate from the unmodified init_img
            init_image=init_image,  # notice that init_image is different from init_img
            mask_image=mask_image,
            strength=strength,
            threshold=threshold,
            perlin=perlin,
            h_symmetry_time_pct=h_symmetry_time_pct,
            v_symmetry_time_pct=v_symmetry_time_pct,
            embiggen=embiggen,
            embiggen_tiles=embiggen_tiles,
            embiggen_strength=embiggen_strength,
            inpaint_replace=inpaint_replace,
            mask_blur_radius=mask_blur_radius,
            safety_checker=self.safety_checker,
            seam_size=seam_size,
            seam_blur=seam_blur,
            seam_strength=seam_strength,
            seam_steps=seam_steps,
            tile_size=tile_size,
            infill_method=infill_method,
            force_outpaint=force_outpaint,
            inpaint_height=inpaint_height,
            inpaint_width=inpaint_width,
            enable_image_debugging=enable_image_debugging,
            free_gpu_mem=self.free_gpu_mem,
            clear_cuda_cache=self.clear_cuda_cache,
        )

        if init_color:
            self.correct_colors(
                image_list=results,
                reference_image_path=init_color,
                image_callback=image_callback,
            )

        if upscale is not None or facetool_strength > 0:
            self.upscale_and_reconstruct(
                results,
                upscale=upscale,
                upscale_denoise_str=upscale_denoise_str,
                facetool=facetool,
                strength=facetool_strength,
                codeformer_fidelity=codeformer_fidelity,
                save_original=save_original,
                image_callback=image_callback,
            )

    except KeyboardInterrupt:
        # Clear the CUDA cache on an exception
        self.clear_cuda_cache()

        if catch_interrupts:
            logger.warning("Interrupted** Partial results will be returned.")
        else:
            # Re-raise the original interrupt so its traceback is preserved.
            raise
    except RuntimeError:
        # Clear the CUDA cache on an exception
        self.clear_cuda_cache()

        print(traceback.format_exc(), file=sys.stderr)
        logger.info("Could not generate image.")

    toc = time.time()
    logger.info("Usage stats:")
    logger.info(f"{len(results)} image(s) generated in " + "%4.2fs" % (toc - tic))
    self.print_cuda_stats()
    return results
|
|
|
|
def gather_cuda_stats(self):
    """Fold the current CUDA memory counters into the running maxima.

    Updates self.max_memory_allocated, self.memory_allocated and
    self.session_peakmem so that each holds the largest value seen so far.
    Does nothing when CUDA is not in use.
    """
    if not self._has_cuda():
        return
    peak_now = torch.cuda.max_memory_allocated(self.device)
    in_use_now = torch.cuda.memory_allocated(self.device)
    self.max_memory_allocated = max(self.max_memory_allocated, peak_now)
    self.memory_allocated = max(self.memory_allocated, in_use_now)
    self.session_peakmem = max(self.session_peakmem, peak_now)
|
|
|
|
def clear_cuda_cache(self):
    """Record memory statistics, then release cached CUDA memory.

    Python garbage collection runs first so that tensors which are no longer
    referenced are freed before torch.cuda.empty_cache() is called. This is a
    no-op when CUDA is not in use.
    """
    if not self._has_cuda():
        return
    self.gather_cuda_stats()
    gc.collect()  # drop unreachable tensors before emptying the cache
    torch.cuda.empty_cache()
|
|
|
|
def clear_cuda_stats(self):
    """Reset the per-generation VRAM counters; the session peak is left alone."""
    self.max_memory_allocated = self.memory_allocated = 0
|
|
|
|
def print_cuda_stats(self):
    """Log VRAM usage for the last generation and for the whole session.

    The counters are refreshed via gather_cuda_stats() first. Nothing is
    logged when CUDA is not in use.
    """
    if not self._has_cuda():
        return
    self.gather_cuda_stats()
    gigabyte = 1e9  # decimal gigabytes, reported with a "G" suffix
    logger.info(
        f"Max VRAM used for this generation: {self.max_memory_allocated / gigabyte:4.2f}G. "
        f"Current VRAM utilization: {self.memory_allocated / gigabyte:4.2f}G"
    )
    logger.info(
        f"Max VRAM used since script start: {self.session_peakmem / gigabyte:4.2f}G"
    )
|
|
|
|
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
# in a nice harmonized call signature. For now we have a bunch of if/elses!
def apply_postprocessor(
    self,
    image_path,
    tool="gfpgan",  # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', 'outcrop' or 'embiggen'
    facetool_strength=0.0,
    codeformer_fidelity=0.75,
    upscale=None,
    upscale_denoise_str=0.75,
    out_direction=None,
    outcrop=None,
    save_original=True,  # to get new name
    callback=None,
    opt=None,
):
    """
    Run one post-processing tool on an existing image file.

    :param image_path: path of the PNG to process. Its metadata supplies the
        seed and prompt when ``opt`` does not.
    :param tool: "gfpgan", "codeformer", "upscale", "outcrop", "embiggen",
        "outpaint" or None. Other values log a warning and return None.
    :param facetool_strength, codeformer_fidelity, upscale, upscale_denoise_str,
        save_original: forwarded to upscale_and_reconstruct() for the face and
        upscale tools. ``facetool_strength`` is forced to 0 for "upscale".
    :param out_direction, outcrop: unused here. The outcrop instructions are
        read from ``opt.outcrop``.
    :param callback: image callback handed to the selected tool.
    :param opt: parsed options object; required despite the None default.
        Read attributes: seed, prompt, outcrop, embiggen, embiggen_tiles,
        embiggen_strength, steps, cfg_scale, width, height, skip_normalize and,
        if present, log_tokenization. ``opt.seed``/``opt.prompt`` are
        overwritten for "outcrop" and ``opt.strength`` for "embiggen".
    :returns: the tool's result for the face/upscale, outcrop and outpaint
        tools. Returns None for embiggen (results go to ``callback``), for
        outcrop without valid instructions, and for missing or unsupported tools.
    """
    # retrieve the seed and prompt; explicit values on ``opt`` win over the PNG metadata
    args = metadata_from_png(image_path)
    seed = opt.seed or args.seed
    if seed is None or seed < 0:
        seed = random.randrange(0, np.iinfo(np.uint32).max)

    prompt = opt.prompt or args.prompt or ""
    logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')

    # try to reuse the same filename prefix as the original file.
    # we take everything up to the first period
    prefix = None
    m = re.match(r"^([^.]+)\.", os.path.basename(image_path))
    if m:
        prefix = m.groups()[0]

    # face fixers and esrgan take an Image, but embiggen takes a path
    image = Image.open(image_path)

    if tool in ("gfpgan", "codeformer", "upscale"):
        if tool == "upscale":
            facetool = "gfpgan"  # but won't be run
            facetool_strength = 0
        else:
            facetool = tool
        return self.upscale_and_reconstruct(
            [[image, seed]],
            facetool=facetool,
            strength=facetool_strength,
            codeformer_fidelity=codeformer_fidelity,
            save_original=save_original,
            upscale=upscale,
            upscale_denoise_str=upscale_denoise_str,
            image_callback=callback,
            prefix=prefix,
        )

    elif tool == "outcrop":
        from .restoration.outcrop import Outcrop

        extend_instructions = {}
        for direction, pixels in _pairwise(opt.outcrop):
            try:
                extend_instructions[direction] = int(pixels)
            except ValueError:
                logger.warning(
                    'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
                )

        opt.seed = seed
        opt.prompt = prompt

        if len(extend_instructions) > 0:
            restorer = Outcrop(
                image,
                self,
            )
            return restorer.process(
                extend_instructions,
                opt=opt,
                orig_opt=args,
                image_callback=callback,
                prefix=prefix,
            )

    elif tool == "embiggen":
        # Only embiggen consumes the prompt conditioning, so it is computed here
        # rather than for every tool (it requires a loaded model).
        # todo: cross-attention control
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
            prompt,
            model=self.model,
            skip_normalize_legacy_blend=opt.skip_normalize,
            # The bare name ``log_tokenization`` is the imported function (always
            # truthy); honour the user's option instead.
            log_tokens=getattr(opt, "log_tokenization", False),
        )
        generator = self.select_generator(embiggen=True)
        opt.strength = opt.embiggen_strength or 0.40
        logger.info(
            f"Setting img2img strength to {opt.strength} for happy embiggening"
        )
        generator.generate(
            prompt,
            sampler=self.sampler,
            steps=opt.steps,
            cfg_scale=opt.cfg_scale,
            ddim_eta=self.ddim_eta,
            conditioning=(uc, c, extra_conditioning_info),
            init_img=image_path,  # not the Image! (sigh)
            init_image=image,  # embiggen wants both! (sigh)
            strength=opt.strength,
            width=opt.width,
            height=opt.height,
            embiggen=opt.embiggen,
            embiggen_tiles=opt.embiggen_tiles,
            embiggen_strength=opt.embiggen_strength,
            image_callback=callback,
            clear_cuda_cache=self.clear_cuda_cache,
        )
    elif tool == "outpaint":
        from .restoration.outpaint import Outpaint

        restorer = Outpaint(image, self)
        return restorer.process(opt, args, image_callback=callback, prefix=prefix)

    elif tool is None:
        logger.warning(
            "please provide at least one postprocessing option, such as -G or -U"
        )
        return None
    else:
        logger.warning(f"postprocessing tool {tool} is not yet supported")
        return None
|
|
|
|
def select_generator(
    self,
    init_image: "Image.Image" = None,
    mask_image: "Image.Image" = None,
    embiggen=None,
    hires_fix: bool = False,
    force_outpaint: bool = False,
):
    """
    Pick and build the generator for the requested operation.

    Priority: hires_fix -> Txt2Img2Img; embiggen -> Embiggen; init image with
    a mask, or force_outpaint -> Inpaint; init image alone -> Img2Img;
    otherwise Txt2Img.

    :param embiggen: embiggen settings (e.g. the list passed to prompt2image)
        or True. Both None and False mean "no embiggen". Previously the default
        was False, which slipped past an ``is not None`` test and selected
        Embiggen for a default call.
    """
    if hires_fix:
        return self._make_txt2img2img()

    if embiggen is not None and embiggen is not False:
        return self._make_embiggen()

    if ((init_image is not None) and (mask_image is not None)) or force_outpaint:
        return self._make_inpaint()

    if init_image is not None:
        return self._make_img2img()

    return self._make_txt2img()
|
|
|
|
def _make_images(
|
|
self,
|
|
img,
|
|
mask,
|
|
width,
|
|
height,
|
|
fit=False,
|
|
text_mask=None,
|
|
invert_mask=False,
|
|
force_outpaint=False,
|
|
):
|
|
init_image = None
|
|
init_mask = None
|
|
if not img:
|
|
return None, None
|
|
|
|
image = self._load_img(img)
|
|
|
|
if image.width < self.width and image.height < self.height:
|
|
logger.warning(
|
|
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
|
)
|
|
|
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
|
if self._has_transparency(image):
|
|
self._transparency_check_and_warning(image, mask, force_outpaint)
|
|
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
|
|
|
if (image.width * image.height) > (
|
|
self.width * self.height
|
|
) and self.size_matters:
|
|
logger.info(
|
|
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
|
)
|
|
self.size_matters = False
|
|
|
|
init_image = self._create_init_image(image, width, height, fit=fit)
|
|
|
|
if mask:
|
|
mask_image = self._load_img(mask)
|
|
init_mask = self._create_init_mask(mask_image, width, height, fit=fit)
|
|
|
|
elif text_mask:
|
|
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
|
|
|
if init_mask and invert_mask:
|
|
init_mask = ImageOps.invert(init_mask)
|
|
|
|
return init_image, init_mask
|
|
|
|
def _make_base(self):
|
|
return self._load_generator("", "Generator")
|
|
|
|
def _make_txt2img(self):
|
|
return self._load_generator(".txt2img", "Txt2Img")
|
|
|
|
def _make_img2img(self):
|
|
return self._load_generator(".img2img", "Img2Img")
|
|
|
|
def _make_embiggen(self):
|
|
return self._load_generator(".embiggen", "Embiggen")
|
|
|
|
def _make_txt2img2img(self):
|
|
return self._load_generator(".txt2img2img", "Txt2Img2Img")
|
|
|
|
def _make_inpaint(self):
|
|
return self._load_generator(".inpaint", "Inpaint")
|
|
|
|
def _load_generator(self, module, class_name):
|
|
mn = f"invokeai.backend.generator{module}"
|
|
cn = class_name
|
|
module = importlib.import_module(mn)
|
|
constructor = getattr(module, cn)
|
|
return constructor(self.model, self.precision)
|
|
|
|
def load_model(self):
    """Eagerly load (or fetch from cache) the model named by self.model_name.

    Returns whatever set_model() returns for that name.
    """
    name = self.model_name
    return self.set_model(name)
|
|
|
|
def set_model(self, model_name):
    """
    Make ``model_name`` (an entry in models.yaml) the active model.

    The model manager loads the model or returns it from its cache. If the
    name is already active and loaded, the current model is returned at once.
    Otherwise, after the switch, self.width, self.height and self.model_hash
    are taken from the model's data. The global RNGs are reseeded with a random
    value via accelerate's set_seed, and the scheduler is rebuilt.

    Raises KeyError if ``model_name`` is not a known model. If loading fails,
    the previously active model is reloaded instead. The original loading
    exception is re-raised when there is no distinct previous model to fall
    back to, or when the fallback load returns nothing. An exception raised by
    the fallback load itself propagates as-is.

    Returns the now-active model object.
    """
    if self.model_name == model_name and self.model is not None:
        return self.model

    previous_model_name = self.model_name

    # the model cache does the loading and offloading
    cache = self.model_manager
    if not cache.valid_model(model_name):
        raise KeyError(
            f'** "{model_name}" is not a known model name. Cannot change.'
        )

    cache.print_vram_usage()

    # have to get rid of all references to model in order
    # to free it from GPU memory
    self.model = None
    self.sampler = None
    self.generators = {}
    gc.collect()
    try:
        model_data = cache.get_model(model_name)
    except Exception as e:
        logger.warning(f"model {model_name} could not be loaded: {str(e)}")
        print(traceback.format_exc(), file=sys.stderr)
        # Retrying the very model that just failed cannot help.
        if previous_model_name is None or previous_model_name == model_name:
            raise e
        logger.warning("trying to reload previous model")
        model_data = cache.get_model(previous_model_name)  # load previous
        if model_data is None:
            raise e
        model_name = previous_model_name

    self.model = model_data["model"]
    self.width = model_data["width"]
    self.height = model_data["height"]
    self.model_hash = model_data["hash"]

    # uncache generators so they pick up new models
    self.generators = {}

    set_seed(random.randrange(0, np.iinfo(np.uint32).max))
    self.model_name = model_name
    self._set_scheduler()  # requires self.model_name to be set first
    return self.model
|
|
|
|
def load_huggingface_concepts(self, concepts: list[str]):
    """Ask the current model's textual-inversion manager to load ``concepts``."""
    manager = self.model.textual_inversion_manager
    manager.load_huggingface_concepts(concepts)
|
|
|
|
@property
def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
    """The Hugging Face concepts library held by the current model's
    textual-inversion manager."""
    manager = self.model.textual_inversion_manager
    return manager.hf_concepts_library
|
|
|
|
@property
def embedding_trigger_strings(self) -> List[str]:
    """Every trigger string known to the current model's textual-inversion manager."""
    manager = self.model.textual_inversion_manager
    return manager.get_all_trigger_strings()
|
|
|
|
def correct_colors(self, image_list, reference_image_path, image_callback=None):
    """Match the colour distribution of each image to a reference image.

    The reference and every image are converted to LAB. Each image's
    histogram is matched to the reference's per channel with
    ``skimage.exposure.match_histograms``, and the result is converted back
    to RGB.

    :param image_list: iterable of ``(image, seed)`` entries. When no
        ``image_callback`` is given, each entry is updated in place
        (``entry[0] = corrected``), so entries must be mutable (e.g. lists).
    :param reference_image_path: path of the image whose colours are the target.
    :param image_callback: optional ``callback(image, seed)`` that receives
        each corrected image instead of it being written back into the entry.
    """
    # Use a context manager so the reference file handle is released even if
    # a later step raises. Normalise the reference to 3-channel RGB: palette
    # ("P") or greyscale ("L") images otherwise give a single-channel array,
    # which an RGB->LAB conversion cannot accept.
    with Image.open(reference_image_path) as reference_image:
        reference_rgb = np.asarray(reference_image.convert("RGB"))
    correction_target = cv2.cvtColor(reference_rgb, cv2.COLOR_RGB2LAB)

    for entry in image_list:
        image, seed = entry
        lab = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
        matched = skimage.exposure.match_histograms(
            lab, correction_target, channel_axis=2
        )
        corrected = Image.fromarray(
            cv2.cvtColor(matched, cv2.COLOR_LAB2RGB).astype("uint8")
        )
        if image_callback is not None:
            image_callback(corrected, seed)
        else:
            entry[0] = corrected
|
|
|
|
def upscale_and_reconstruct(
    self,
    image_list,
    facetool="gfpgan",
    upscale=None,
    upscale_denoise_str=0.75,
    strength=0.0,
    codeformer_fidelity=0.75,
    save_original=False,
    image_callback=None,
    prefix=None,
):
    """Apply optional face restoration, then optional ESRGAN upscaling, to images.

    :param image_list: iterable of ``(image, seed, _)`` entries; the third
        element is ignored. When no ``image_callback`` is given, each entry is
        updated in place (``entry[0] = result``), so entries must be mutable.
    :param facetool: ``"gfpgan"`` or ``"codeformer"``. Selects the face
        restorer used when ``strength > 0``.
    :param upscale: optional sequence ``[scale]`` or ``[scale, esrgan_strength]``.
        A missing ESRGAN strength defaults to 0.75. The caller's sequence is
        not modified.
    :param upscale_denoise_str: forwarded to ESRGAN as ``denoise_str``.
    :param strength: face-restoration strength; ``0`` skips face restoration.
    :param codeformer_fidelity: forwarded to CodeFormer as ``fidelity``.
    :param save_original: accepted for interface compatibility; not used here.
    :param image_callback: optional
        ``callback(image, seed, upscaled=True, use_prefix=prefix)``.
    :param prefix: forwarded to ``image_callback`` as ``use_prefix``.
    :returns: a list of ``[image, seed]`` pairs, one per input entry.

    Any exception raised during restoration or upscaling is logged. The image
    is then passed on in whatever state it had reached.
    """
    results = []
    for entry in image_list:
        image, seed, _ = entry
        try:
            if strength > 0:
                if self.gfpgan is not None or self.codeformer is not None:
                    if facetool == "gfpgan":
                        if self.gfpgan is None:
                            logger.info(
                                "GFPGAN not found. Face restoration is disabled."
                            )
                        else:
                            image = self.gfpgan.process(image, strength, seed)
                    elif facetool == "codeformer":
                        if self.codeformer is None:
                            logger.info(
                                "CodeFormer not found. Face restoration is disabled."
                            )
                        else:
                            # CodeFormer runs on the CPU when the generator's device is MPS.
                            cf_device = (
                                "cpu" if str(self.device) == "mps" else self.device
                            )
                            image = self.codeformer.process(
                                image=image,
                                strength=strength,
                                device=cf_device,
                                seed=seed,
                                fidelity=codeformer_fidelity,
                            )
                else:
                    logger.info("Face Restoration is disabled.")
            if upscale is not None:
                if self.esrgan is not None:
                    # Default the ESRGAN strength locally rather than appending
                    # to ``upscale``, which belongs to the caller.
                    esrgan_strength = upscale[1] if len(upscale) >= 2 else 0.75
                    image = self.esrgan.process(
                        image,
                        esrgan_strength,
                        seed,
                        int(upscale[0]),
                        denoise_str=upscale_denoise_str,
                    )
                else:
                    logger.info("ESRGAN is disabled. Image not upscaled.")
        except Exception as e:
            # Best effort: report the failure and keep the image as it stands.
            logger.info(
                f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
            )

        if image_callback is not None:
            image_callback(image, seed, upscaled=True, use_prefix=prefix)
        else:
            entry[0] = image

        results.append([image, seed])

    return results
|
|
|
|
def apply_textmask(
    self, image_path: str, prompt: str, callback, threshold: float = 0.5
):
    """Segment *image_path* by text *prompt* and hand three derived images to *callback*.

    A ``Txt2Mask(device=self.device, refined=True)`` is created and cached on
    ``self.txt2mask`` if none exists yet. The segmentation result is turned
    into three images, each passed as
    ``callback(image, f"{safe_prompt}.<suffix>", use_prefix=<image stem>)``:

    - ``to_transparent()`` with suffix ``deselected``
    - ``to_transparent(invert=True)`` with suffix ``selected``
    - ``to_mask(threshold)`` with suffix ``masked``

    ``safe_prompt`` is the prompt with ``<>:"/\\|?*`` replaced by ``_``,
    truncated to 50 characters, and stripped of trailing spaces and dots.

    :raises AssertionError: if *image_path* does not exist (when assertions
        are enabled).
    """
    assert os.path.exists(
        image_path
    ), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **'
    basename = os.path.splitext(os.path.basename(image_path))[0]

    if self.txt2mask is None:
        self.txt2mask = Txt2Mask(device=self.device, refined=True)
    segmentation = self.txt2mask.segment(image_path, prompt)

    # All three derived images are computed before any callback runs.
    outputs = (
        (segmentation.to_transparent(), "deselected"),
        (segmentation.to_transparent(invert=True), "selected"),
        (segmentation.to_mask(threshold), "masked"),
    )

    # Make the prompt usable inside a file name.
    safe_prompt = re.sub(r'[<>:"/\\|?*]', "_", prompt)[:50].rstrip(" .")

    for derived, suffix in outputs:
        callback(derived, f"{safe_prompt}.{suffix}", use_prefix=basename)
|
|
|
|
# to help WebGUI - front end to generator util function
def sample_to_image(self, samples):
    """Convert *samples* to an image using a freshly made base generator."""
    base_generator = self._make_base()
    return base_generator.sample_to_image(samples)
|
|
|
|
def sample_to_lowres_estimated_image(self, samples):
    """Produce a low-resolution estimated image from *samples* via a new base generator."""
    base_generator = self._make_base()
    return base_generator.sample_to_lowres_estimated_image(samples)
|
|
|
|
def is_legacy_model(self, model_name) -> bool:
    """Ask the model manager whether *model_name* is a legacy model."""
    manager = self.model_manager
    return manager.is_legacy(model_name)
|
|
|
|
def _set_scheduler(self):
    """Set ``self.sampler`` according to ``self.sampler_name``.

    Known sampler names map to a diffusers scheduler class. A fresh instance
    is built with ``from_config`` from the model's current scheduler config
    plus a per-entry ``cpu_only`` key. Unknown names fall back to the model's
    own scheduler object (``self.model.scheduler``), and a warning is logged.
    The chosen sampler gets a ``uses_inpainting_model`` attribute if it lacks
    one.
    """
    default = self.model.scheduler

    # See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
    scheduler_map = dict(
        ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
        dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
        k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
        k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
        # DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
        # the k-diffusers implementation nor included in EDM (Karras 2022), but we can
        # provide an alias for compatibility.
        k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
        k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
        k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
        k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
        k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
        plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
        unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True)),
    )

    if self.sampler_name in scheduler_map:
        sampler_class, sampler_extra_config = scheduler_map[self.sampler_name]
        self.sampler = sampler_class.from_config(
            {**default.config, **sampler_extra_config}
        )
        logger.info(
            f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
        )
    else:
        self.sampler = default
        # This is a fallback, so log it as a warning. Name only the
        # fallback's class; interpolating the scheduler object itself
        # presumably renders its full config.
        logger.warning(
            f"Unsupported Sampler: {self.sampler_name}. "
            f"Defaulting to {type(default).__name__}"
        )

    if not hasattr(self.sampler, "uses_inpainting_model"):
        # FIXME: terrible kludge!
        self.sampler.uses_inpainting_model = lambda: False
|
|
|
|
def _load_img(self, img) -> Image.Image:
    """Return a PIL image for *img* with its EXIF orientation applied.

    :param img: an already-open ``PIL.Image.Image``, a filesystem path
        (``str`` or ``os.PathLike``), or any other object handed straight to
        ``Image.open``.
    :raises AssertionError: if a path is given that does not exist (when
        assertions are enabled).
    """
    if isinstance(img, Image.Image):
        image = img
        logger.info(f"using provided input image of size {image.width}x{image.height}")
    elif isinstance(img, (str, os.PathLike)):
        # pathlib.Path and other path-likes get the same existence check and
        # log message as plain strings.
        assert os.path.exists(img), f"{img}: File not found"

        image = Image.open(img)
        logger.info(
            f"loaded input image of size {image.width}x{image.height} from {img}"
        )
    else:
        image = Image.open(img)
        logger.info(f"loaded input image of size {image.width}x{image.height}")
    # Apply the EXIF orientation tag (if any) so the pixels are upright.
    image = ImageOps.exif_transpose(image)
    return image
|
|
|
|
def _create_init_image(self, image: Image.Image, width, height, fit=True):
    """Prepare *image* as an RGBA init image.

    Non-RGBA input is converted to RGBA first. With ``fit`` the result is
    passed to ``self._fit_image`` with a ``(width, height)`` box; otherwise
    it is passed to ``self._squeeze_image``.
    """
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    if not fit:
        return self._squeeze_image(rgba)
    return self._fit_image(rgba, (width, height))
|
|
|
|
def _create_init_mask(self, image, width, height, fit=True):
    """Turn *image* into an RGB black/white mask sized for generation.

    The mask comes from ``self._image_to_mask`` and is converted to RGB. With
    ``fit`` it is then resized by ``self._fit_image`` into a
    ``(width, height)`` box; otherwise ``self._squeeze_image`` is used.
    """
    rgb_mask = self._image_to_mask(image).convert("RGB")
    if fit:
        return self._fit_image(rgb_mask, (width, height))
    return self._squeeze_image(rgb_mask)
|
|
|
|
# The mask is expected to have the region to be inpainted
# with alpha transparency. It converts it into a black/white
# image with the transparent part black.
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image.Image:
    """Return a single-channel ("L") mask derived from *mask_image*.

    - "L" images are returned as-is (the same object).
    - "RGB" and "P" images are converted to greyscale.
    - Images with an "A" band (e.g. "RGBA", "LA", "PA") yield their alpha
      channel, so fully transparent pixels become 0 (black).
    - Other images with four or more bands yield their fourth band, keeping
      the historical fixed band-3 lookup (e.g. "RGBa").
    - Anything else is converted to greyscale.

    :param invert: if true, the mask is inverted before it is returned.
    """
    if mask_image.mode == "L":
        mask = mask_image
    elif mask_image.mode in ("RGB", "P"):
        mask = mask_image.convert("L")
    else:
        bands = mask_image.getbands()
        if "A" in bands:
            # Find the alpha band by name. It is band 3 only for 4-band modes;
            # for "LA"/"PA" it is band 1, where a fixed index 3 raises.
            mask = mask_image.getchannel("A")
        elif len(bands) >= 4:
            mask = mask_image.getchannel(3)
        else:
            mask = mask_image.convert("L")
    if invert:
        mask = ImageOps.invert(mask)
    return mask
|
|
|
|
def _txt2mask(
    self, image: Image, text_mask: list, width, height, fit=True
) -> Image:
    """Build an RGB mask for *image* from a text prompt.

    :param text_mask: ``[prompt]`` or ``[prompt, confidence_level]``. The
        confidence level defaults to 0.5 and is passed to ``to_mask`` as a float.
    :param fit: resize with ``self._fit_image`` into a ``(width, height)``
        box when true; otherwise use ``self._squeeze_image``.
    """
    prompt = text_mask[0]
    confidence_level = 0.5
    if len(text_mask) > 1:
        confidence_level = text_mask[1]

    # NOTE(review): self.txt2mask is a shared cache. This path creates it
    # without refined=True while apply_textmask creates it with refined=True,
    # so whichever runs first decides the variant. Confirm this is intended.
    if self.txt2mask is None:
        self.txt2mask = Txt2Mask(device=self.device)

    mask = (
        self.txt2mask.segment(image, prompt)
        .to_mask(float(confidence_level))
        .convert("RGB")
    )
    if fit:
        return self._fit_image(mask, (width, height))
    return self._squeeze_image(mask)
|
|
|
|
def _has_transparency(self, image):
    """Return True if *image* has any transparency.

    Transparency is detected in one of two ways:

    - a non-None ``"transparency"`` entry in ``image.info``;
    - an alpha ("A") band whose minimum value is below 255. This covers
      "RGBA" and also "LA"/"PA", which were previously ignored.
    """
    if image.info.get("transparency") is not None:
        return True
    # For "P" images the only transparency marker checked is
    # info["transparency"], handled above. The old per-palette-index scan ran
    # only when that entry was absent, so it compared against -1 and could
    # never match.
    if "A" in image.getbands():
        min_alpha, _ = image.getchannel("A").getextrema()
        return min_alpha < 255
    return False
|
|
|
|
def _check_for_erasure(self, image: Image.Image) -> bool:
    """Return True if every fully transparent pixel is pure black or pure white.

    That pattern suggests the colours under the transparent region were
    erased when the mask was made.

    Only "RGBA" images are examined; anything else returns False. "RGB" used
    to pass the mode guard and then raise ``IndexError`` when indexing a
    non-existent alpha value. An "RGBA" image with no fully transparent
    pixels returns True, as before.
    """
    if image.mode != "RGBA":
        return False
    pixels = np.asarray(image)  # RGBA -> array of shape (height, width, 4)
    # RGB values of the pixels whose alpha is exactly 0, shape (k, 3).
    transparent_rgb = pixels[pixels[..., 3] == 0, :3]
    is_black = np.all(transparent_rgb == 0, axis=1)
    is_white = np.all(transparent_rgb == 255, axis=1)
    return bool(np.all(is_black | is_white))
|
|
|
|
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
    """Log inpainting notices for an init image that has transparent areas.

    Nothing happens when a truthy ``mask`` is supplied. Otherwise an
    informational message is logged. Unless ``force_outpaint`` is set,
    ``self._check_for_erasure(image)`` is then consulted; if it returns true,
    a second message advises preserving the colours or passing an explicit
    mask.
    """
    if mask:
        return
    logger.info(
        "Initial image has transparent areas. Will inpaint in these regions."
    )
    if force_outpaint or not self._check_for_erasure(image):
        return
    logger.info(
        "Colors underneath the transparent region seem to have been erased.\n"
        "Inpainting will be suboptimal. Please preserve the colors when making\n"
        "a transparency mask, or provide mask explicitly using --init_mask (-M)."
    )
|
|
|
|
def _squeeze_image(self, image):
    """Resize *image* to the size chosen by ``self._resolution_check``, if any.

    When no resize is needed, the same image object is returned.
    """
    new_w, new_h, needs_resize = self._resolution_check(image.width, image.height)
    return InitImageResizer(image).resize(new_w, new_h) if needs_resize else image
|
|
|
|
def _fit_image(self, image, max_dimensions):
    """Resize *image* to fit inside the ``(width, height)`` box *max_dimensions*.

    Resizing is delegated to ``InitImageResizer``. The inline note says it
    also truncates the result to multiples of 64. Before and after sizes are
    logged.
    """
    box_w, box_h = max_dimensions
    logger.info(f"image will be resized to fit inside a box {box_w}x{box_h} in size.")
    # note that InitImageResizer does the multiple of 64 truncation internally
    resized = InitImageResizer(image).resize(width=box_w, height=box_h)
    logger.info(
        f"after adjusting image dimensions to be multiples of 64, init image is {resized.width}x{resized.height}"
    )
    return resized
|
|
|
|
def _resolution_check(self, width, height, log=False):
    """Round *width* and *height* down to multiples of 64.

    :param log: when true and a resize is needed, log the new size.
    :returns: ``(width, height, resize_needed)``. The original values and
        False are returned if both are already multiples of 64; otherwise the
        rounded values and True.
    """
    new_w = width - width % 64
    new_h = height - height % 64
    if (new_w, new_h) == (width, height):
        return width, height, False
    if log:
        logger.info(
            f"Provided width and height must be multiples of 64. Auto-resizing to {new_w}x{new_h}"
        )
    return new_w, new_h, True
|
|
|
|
def _has_cuda(self):
    """Report whether ``self.device`` is a CUDA device, judged by its ``type`` attribute."""
    device_type = self.device.type
    return device_type == "cuda"
|
|
|
|
def write_intermediate_images(self, modulus, path):
    """Return a step callback that saves every *modulus*-th sample as a PNG.

    *path* is created if it does not exist. The callback counts its calls from
    0. On each call whose index is divisible by *modulus*, it converts the
    sample with ``self.sample_to_image`` and saves it as
    ``<path>/<index:03>.png``.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    step = -1

    def callback(img):
        nonlocal step
        step += 1
        if step % modulus == 0:
            self.sample_to_image(img).save(os.path.join(path, f"{step:03}.png"), "PNG")

    return callback
|
|
|
|
|
|
def _pairwise(iterable):
    """Group *iterable* into consecutive, non-overlapping pairs.

    ``s -> (s0, s1), (s2, s3), (s4, s5), ...``. A trailing unpaired element
    is dropped. Unlike ``itertools.pairwise``, the pairs do not overlap.
    """
    return zip(*[iter(iterable)] * 2)
|