InvokeAI/invokeai/backend/generator/base.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

645 lines
24 KiB
Python
Raw Normal View History

2023-03-03 06:02:00 +00:00
"""
2023-02-28 05:37:13 +00:00
Base class for invokeai.backend.generator.*
including img2img, txt2img, and inpaint
2023-03-03 06:02:00 +00:00
"""
2023-02-28 05:37:13 +00:00
from __future__ import annotations
import itertools
import dataclasses
import diffusers
2023-02-28 05:37:13 +00:00
import os
import random
import traceback
from abc import ABCMeta
from argparse import Namespace
2023-02-28 05:37:13 +00:00
from contextlib import nullcontext
import cv2
import numpy as np
import torch
2023-03-03 06:02:00 +00:00
from PIL import Image, ImageChops, ImageFilter
2023-03-05 02:16:59 +00:00
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
2023-02-28 05:37:13 +00:00
from tqdm import trange
2023-03-26 05:54:46 +00:00
from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
2023-02-28 05:37:13 +00:00
2023-04-29 13:43:40 +00:00
import invokeai.backend.util.logging as logger
2023-03-11 13:33:23 +00:00
from ..image_util import configure_model_padding
2023-03-03 05:02:15 +00:00
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
2023-05-11 12:40:03 +00:00
from ..stable_diffusion.schedulers import SCHEDULER_MAP
2023-02-28 05:37:13 +00:00
# Ratio between image dimensions and latent dimensions. Generator.__init__
# copies it into `downsampling_factor`, which get_noise() uses to size the
# noise tensor (height // factor, width // factor).
downsampling = 8
2023-03-03 06:02:00 +00:00
@dataclass
class InvokeAIGeneratorBasicParams:
    '''
    Default generation parameters held by an InvokeAIGenerator.

    InvokeAIGenerator.generate() converts an instance of this class into a
    dict and merges any per-call keyword arguments over it.
    '''
    # --- seed and image geometry ---
    seed: Optional[int] = None
    width: int = 512
    height: int = 512

    # --- sampling ---
    cfg_scale: float = 7.5
    steps: int = 20
    ddim_eta: float = 0.0
    scheduler: str = 'ddim'
    precision: str = 'float16'

    # --- noise shaping ---
    perlin: float = 0.0
    threshold: float = 0.0

    # --- seamless padding (handed to configure_model_padding) ---
    seamless: bool = False
    # default_factory gives every instance its own list
    seamless_axes: List[str] = field(default_factory=lambda: ['x', 'y'])

    # --- symmetry ---
    h_symmetry_time_pct: Optional[float] = None
    v_symmetry_time_pct: Optional[float] = None

    # --- variations (handed to Generator.set_variation) ---
    variation_amount: float = 0.0
    with_variations: list = field(default_factory=list)

    # --- optional post-generation safety check ---
    safety_checker: Optional[SafetyChecker] = None
@dataclass
class InvokeAIGeneratorOutput:
    '''
    Result of a single generation, as yielded by InvokeAIGenerator.generate().

    Bundles the generated image, the seed it was produced with, the hash of
    the model, any attention-map images collected during generation, and a
    Namespace (.params) holding the model name together with all of the
    generate() parameters that were in effect.
    '''
    image: Image.Image
    seed: int
    model_hash: str
    attention_maps_images: List[Image.Image]
    params: Namespace
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
    '''
    Base class of the high-level generator wrappers.

    Holds a model_info dict plus a set of default parameters, and exposes
    generate(), which drives the low-level Generator class returned by
    _generator_class() and yields InvokeAIGeneratorOutput objects.
    '''
    def __init__(self,
                 model_info: dict,
                 params: Optional[InvokeAIGeneratorBasicParams]=None,
                 ):
        '''
        :param model_info: dict providing the 'model_name', 'model' and 'hash'
                           entries that generate() reads
        :param params: default generation parameters; when omitted a fresh
                       InvokeAIGeneratorBasicParams() is created
        '''
        self.model_info=model_info
        # A default instance built in the signature is created only once and
        # would be shared (and mutated) by every wrapper, so build it here.
        self.params=params if params is not None else InvokeAIGeneratorBasicParams()

    def generate(self,
                 prompt: str='',
                 callback: Optional[Callable]=None,
                 step_callback: Optional[Callable]=None,
                 iterations: Optional[int]=1,
                 **keyword_args,
                 )->Iterator[InvokeAIGeneratorOutput]:
        '''
        Return an iterator across the indicated number of generations.
        Each time the iterator is advanced it returns an InvokeAIGeneratorOutput
        object. Use like this:

           outputs = txt2img.generate(prompt='banana sushi', iterations=5)
           for result in outputs:
               print(result.image, result.seed)

        In the typical case of wanting to get just a single image, iterations
        defaults to 1 and do:

           output = next(txt2img.generate(prompt='banana sushi'))

        Pass None to get an infinite iterator.

           outputs = txt2img.generate(prompt='banana sushi', iterations=None)
           for o in outputs:
               print(o.image, o.seed)

        :param prompt: text prompt, turned into conditioning for the model
        :param callback: if given, called with each InvokeAIGeneratorOutput
                         just before it is yielded
        :param step_callback: forwarded to the low-level generator
        :param iterations: number of outputs to produce; any falsy value
                           (None, but also 0) gives an endless iterator
        :param keyword_args: per-call overrides merged on top of self.params
        '''
        # Per-call keyword arguments take precedence over the stored defaults.
        generator_args = dataclasses.asdict(self.params)
        generator_args.update(keyword_args)

        model_info = self.model_info
        model_name = model_info['model_name']
        model:StableDiffusionGeneratorPipeline = model_info['model']
        model_hash = model_info['hash']
        scheduler: Scheduler = self.get_scheduler(
            model=model,
            scheduler_name=generator_args.get('scheduler')
        )
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
        gen_class = self._generator_class()
        generator = gen_class(model, self.params.precision)

        # Decide from the merged arguments (not only the stored defaults) so a
        # per-call override is honoured, and also enable the generator's
        # "with_variations only" path, which it supports with amount == 0.
        variation_amount = generator_args.get('variation_amount')
        with_variations = generator_args.get('with_variations')
        if variation_amount > 0 or with_variations:
            generator.set_variation(generator_args.get('seed'),
                                    variation_amount,
                                    with_variations
                                    )

        # A diffusers pipeline gets its unet and vae configured separately;
        # anything else is configured as a single component.
        if isinstance(model, DiffusionPipeline):
            padded_components = [model.unet, model.vae]
        else:
            padded_components = [model]
        for component in padded_components:
            configure_model_padding(component,
                                    generator_args.get('seamless',False),
                                    generator_args.get('seamless_axes')
                                    )

        iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
        for _ in iteration_count:
            results = generator.generate(prompt,
                                         conditioning=(uc, c, extra_conditioning_info),
                                         step_callback=step_callback,
                                         sampler=scheduler,
                                         **generator_args,
                                         )
            # Only the first entry is used; each entry is
            # [image, seed, attention_maps_images].
            output = InvokeAIGeneratorOutput(
                image=results[0][0],
                seed=results[0][1],
                attention_maps_images=results[0][2],
                model_hash = model_hash,
                params=Namespace(model_name=model_name,**generator_args),
            )
            if callback:
                callback(output)
            yield output

    @classmethod
    def schedulers(cls)->List[str]:
        '''
        Return list of all the schedulers that we currently handle.
        '''
        return list(SCHEDULER_MAP.keys())

    def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
        '''
        Instantiate generator_class for the given model using the precision
        stored in self.params.
        '''
        return generator_class(model, self.params.precision)

    def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
        '''
        Build a scheduler of the requested kind from the model's scheduler
        configuration. Unknown names fall back to the 'ddim' entry of
        SCHEDULER_MAP.
        '''
        scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

        # Always start from the pristine config: if the current config already
        # carries a "_backup" (i.e. it was produced by an earlier call), use
        # that backup as the base and keep it stored under "_backup" again.
        scheduler_config = model.scheduler.config
        if "_backup" in scheduler_config:
            scheduler_config = scheduler_config["_backup"]
        scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
        scheduler = scheduler_class.from_config(scheduler_config)

        # hack copied over from generate.py
        if not hasattr(scheduler, 'uses_inpainting_model'):
            scheduler.uses_inpainting_model = lambda: False
        return scheduler

    @classmethod
    def _generator_class(cls)->Type[Generator]:
        '''
        Return the low-level generator class that generate() instantiates.
        Derived classes override this; the base implementation returns
        the Generator base class.
        '''
        return Generator
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
    '''
    Wrapper whose low-level generator is the Txt2Img class from .txt2img.
    '''
    @classmethod
    def _generator_class(cls):
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- verify against .txt2img).
        from .txt2img import Txt2Img as generator_class
        return generator_class
# ------------------------------------
class Img2Img(InvokeAIGenerator):
    '''
    Wrapper whose low-level generator is the Img2Img class from .img2img.
    '''
    def generate(self,
                 init_image: Image.Image | torch.FloatTensor,
                 strength: float=0.75,
                 **keyword_args
                 )->Iterator[InvokeAIGeneratorOutput]:
        '''
        Same as InvokeAIGenerator.generate(), with init_image and strength
        added to the arguments that are forwarded.
        '''
        forwarded = dict(init_image=init_image,
                         strength=strength,
                         **keyword_args)
        return super().generate(**forwarded)

    @classmethod
    def _generator_class(cls):
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- verify against .img2img).
        from .img2img import Img2Img as generator_class
        return generator_class
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
    '''
    Wrapper whose low-level generator is the Inpaint class from .inpaint.
    Accepts everything Img2Img.generate() accepts, plus the mask image and
    the seam/infill settings.
    '''
    def generate(self,
                 mask_image: Image.Image | torch.FloatTensor,
                 # Seam settings - when 0, doesn't fill seam
                 seam_size: int = 96,
                 seam_blur: int = 16,
                 seam_strength: float = 0.7,
                 seam_steps: int = 30,
                 tile_size: int = 32,
                 inpaint_replace=False,
                 infill_method=None,
                 inpaint_width=None,
                 inpaint_height=None,
                 # `tuple(int)` is a call, not a type: it raises TypeError as
                 # soon as the annotation is evaluated (e.g. get_type_hints).
                 inpaint_fill: tuple[int, ...] = (0x7F, 0x7F, 0x7F, 0xFF),
                 **keyword_args
                 )->Iterator[InvokeAIGeneratorOutput]:
        '''
        Forward the mask image and all seam/infill settings, together with
        any remaining keyword arguments (init_image, strength, prompt, ...),
        to Img2Img.generate().
        '''
        return super().generate(
            mask_image=mask_image,
            seam_size=seam_size,
            seam_blur=seam_blur,
            seam_strength=seam_strength,
            seam_steps=seam_steps,
            tile_size=tile_size,
            inpaint_replace=inpaint_replace,
            infill_method=infill_method,
            inpaint_width=inpaint_width,
            inpaint_height=inpaint_height,
            inpaint_fill=inpaint_fill,
            **keyword_args
        )

    @classmethod
    def _generator_class(cls):
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- verify against .inpaint).
        from .inpaint import Inpaint
        return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
    '''
    Wrapper whose low-level generator is the Embiggen class from .embiggen.
    '''
    def generate(
            self,
            embiggen: list=None,
            embiggen_tiles: list = None,
            strength: float=0.75,
            **kwargs)->Iterator[InvokeAIGeneratorOutput]:
        '''
        Same as Txt2Img.generate(), with embiggen, embiggen_tiles and
        strength added to the arguments that are forwarded.
        '''
        forwarded = dict(embiggen=embiggen,
                         embiggen_tiles=embiggen_tiles,
                         strength=strength,
                         **kwargs)
        return super().generate(**forwarded)

    @classmethod
    def _generator_class(cls):
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- verify against .embiggen).
        from .embiggen import Embiggen as generator_class
        return generator_class
2023-03-13 13:11:09 +00:00
2023-02-28 05:37:13 +00:00
class Generator:
    """
    Low-level image generator base class.

    Subclasses supply get_make_image(); this class provides the generation
    loop, seed handling and noise helpers built around it.
    """
    downsampling_factor: int
    latent_channels: int
    precision: str
    model: DiffusionPipeline

    def __init__(self, model: DiffusionPipeline, precision: str):
        """
        Store the model and precision and reset all per-generation state.

        :param model: pipeline to generate with; its `channels` attribute
                      becomes latent_channels
        :param precision: precision name; see torch_dtype() for how it maps
                          to a torch dtype
        """
        # what we generate with
        self.model = model
        self.precision = precision
        self.latent_channels = model.channels
        self.downsampling_factor = downsampling  # BUG: should come from model or config

        # seed / variation state (see set_variation and new_seed)
        self.seed = None
        self.variation_amount = 0
        self.with_variations = []

        # noise options
        self.perlin = 0.0
        self.threshold = 0
        self.use_mps_noise = False

        # populated on each generate() call
        self.safety_checker = None
        self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
2023-03-03 06:02:00 +00:00
def get_make_image(self, prompt, **kwargs):
    """
    Returns a function returning an image derived from the prompt and the initial image
    Return value depends on the seed at the time you call it

    This base implementation is abstract: it always raises
    NotImplementedError and must be overridden by a descendant class.
    """
    # The message previously named a non-existent image_iterator() method,
    # which sent anyone reading the traceback looking for the wrong thing.
    raise NotImplementedError(
        "get_make_image() must be implemented in a descendant class"
    )
2023-02-28 05:37:13 +00:00
def set_variation(self, seed, variation_amount, with_variations):
    """
    Record the variation settings consulted later by generate() and
    generate_initial_noise().

    :param seed: stored as self.seed
    :param variation_amount: stored as self.variation_amount
    :param with_variations: stored as self.with_variations; iterated as
                            (seed, weight) pairs by generate_initial_noise()
    """
    (
        self.seed,
        self.variation_amount,
        self.with_variations,
    ) = (seed, variation_amount, with_variations)
def generate(
    self,
    prompt,
    width,
    height,
    sampler,
    init_image=None,
    iterations=1,
    seed=None,
    image_callback=None,
    step_callback=None,
    threshold=0.0,
    perlin=0.0,
    h_symmetry_time_pct=None,
    v_symmetry_time_pct=None,
    safety_checker: SafetyChecker=None,
    free_gpu_mem: bool = False,
    **kwargs,
):
    """
    Run the generation loop and return the collected results.

    The image-making function is obtained once from get_make_image(); it is
    then called `iterations` times, each time with freshly chosen noise and
    the current seed. After every image a new random seed is drawn.

    :param prompt, sampler, init_image, width, height, step_callback,
           threshold, perlin, h_symmetry_time_pct, v_symmetry_time_pct:
           forwarded to get_make_image() (width/height also size the noise)
    :param iterations: number of images to produce
    :param seed: starting seed; None or a negative value picks a random one
    :param image_callback: if given, called after each image as
           image_callback(image, seed, first_seed=..., attention_maps_image=...)
    :param safety_checker: if given, its check() is applied to every image
    :param free_gpu_mem: stored on self.free_gpu_mem
    :param kwargs: forwarded to get_make_image(); an optional
           "clear_cuda_cache" callable is invoked after every iteration
    :return: list of [image, seed, attention_maps_images] entries; every
             entry refers to the same attention_maps_images list
    """
    scope = nullcontext
    self.safety_checker = safety_checker
    self.free_gpu_mem = free_gpu_mem
    attention_maps_images = []

    def attention_maps_callback(saver):
        attention_maps_images.append(saver.get_stacked_maps_image())

    make_image = self.get_make_image(
        prompt,
        sampler=sampler,
        init_image=init_image,
        width=width,
        height=height,
        step_callback=step_callback,
        threshold=threshold,
        perlin=perlin,
        h_symmetry_time_pct=h_symmetry_time_pct,
        v_symmetry_time_pct=v_symmetry_time_pct,
        attention_maps_callback=attention_maps_callback,
        **kwargs,
    )
    results = []
    seed = seed if seed is not None and seed >= 0 else self.new_seed()
    first_seed = seed
    seed, initial_noise = self.generate_initial_noise(seed, width, height)

    # Loop-invariant: look the optional cache-clearing hook up once.
    clear_cuda_cache = kwargs.get("clear_cuda_cache")

    # There used to be an additional self.model.ema_scope() here, but it breaks
    # the inpaint-1.5 model. Not sure what it did.... ?
    with scope(self.model.device.type):
        for _ in trange(iterations, desc="Generating"):
            x_T = None
            if self.variation_amount > 0:
                set_seed(seed)
                target_noise = self.get_noise(width, height)
                x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
            elif initial_noise is not None:
                # i.e. we specified particular variations
                x_T = initial_noise
            else:
                set_seed(seed)
                try:
                    x_T = self.get_noise(width, height)
                except Exception:
                    # Deliberately best-effort: report the failure and carry
                    # on with x_T = None. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt/SystemExit propagate.
                    logger.error("An error occurred while getting initial noise")
                    print(traceback.format_exc())

            # Pass on the seed in case a layer beneath us needs to generate noise on its own.
            image = make_image(x_T, seed)

            if self.safety_checker is not None:
                image = self.safety_checker.check(image)

            results.append([image, seed, attention_maps_images])

            if image_callback is not None:
                attention_maps_image = (
                    None
                    if len(attention_maps_images) == 0
                    else attention_maps_images[-1]
                )
                image_callback(
                    image,
                    seed,
                    first_seed=first_seed,
                    attention_maps_image=attention_maps_image,
                )

            seed = self.new_seed()

            # Free up memory from the last generation.
            if clear_cuda_cache is not None:
                clear_cuda_cache()

    return results
2023-03-03 06:02:00 +00:00
def sample_to_image(self, samples) -> Image.Image:
    """
    Decode samples returned from a sampler with the model and return the
    first decoded image as a PIL Image.
    """
    pipeline = self.model
    with torch.inference_mode():
        decoded = pipeline.decode_latents(samples)
    pil_images = pipeline.numpy_to_pil(decoded)
    return pil_images[0]
2023-03-03 06:02:00 +00:00
def repaste_and_color_correct(
    self,
    result: Image.Image,
    init_image: Image.Image,
    init_mask: Image.Image,
    mask_blur_radius: int = 8,
) -> Image.Image:
    """
    Colour-match a generated image to the original and paste the original
    back over it through a (optionally eroded and blurred) mask.

    :param result: the generated image
    :param init_image: the original image; if None, `result` is returned as is
    :param init_mask: the mask image; if None, `result` is returned as is
    :param mask_blur_radius: box-blur radius applied to the mask before
                             pasting; 0 or less uses the mask unblurred
    :return: the colour-corrected image with the original pasted in

    NOTE(review): reads self.pil_image, which Generator.__init__ does not
    set -- presumably a subclass assigns it before calling this; confirm.
    """
    if init_image is None or init_mask is None:
        return result

    # Get the original alpha channel of the mask if there is one.
    # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
    pil_init_mask = (
        init_mask.getchannel("A")
        if init_mask.mode == "RGBA"
        else init_mask.convert("L")
    )
    pil_init_image = init_image.convert(
        "RGBA"
    )  # Add an alpha channel if one doesn't exist

    # Build an image with only visible pixels from source to use as reference for color-matching.
    init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
    init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
    init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)

    # Get numpy version of result
    np_image = np.asarray(result, dtype=np.uint8)

    # Select pixels where both alpha and mask are non-zero. Compare each
    # array separately: multiplying two uint8 arrays wraps modulo 256
    # (e.g. 16 * 16 == 0), which silently dropped valid pixels.
    mask_pixels = (init_a_pixels > 0) & (init_mask_pixels > 0)
    np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
    np_image_masked = np_image[mask_pixels, :]

    if np_init_rgb_pixels_masked.size > 0:
        # Mask and calculate mean and standard deviation
        init_means = np_init_rgb_pixels_masked.mean(axis=0)
        init_std = np_init_rgb_pixels_masked.std(axis=0)
        gen_means = np_image_masked.mean(axis=0)
        gen_std = np_image_masked.std(axis=0)
        # A perfectly flat channel has zero deviation; dividing by it would
        # fill the image with NaN/inf, so normalise by 1 in that case.
        gen_std = np.where(gen_std == 0, 1, gen_std)

        # Color correct
        np_matched_result = np_image.copy()
        np_matched_result[:, :, :] = (
            (
                (
                    (
                        np_matched_result[:, :, :].astype(np.float32)
                        - gen_means[None, None, :]
                    )
                    / gen_std[None, None, :]
                )
                * init_std[None, None, :]
                + init_means[None, None, :]
            )
            .clip(0, 255)
            .astype(np.uint8)
        )
        matched_result = Image.fromarray(np_matched_result, mode="RGB")
    else:
        matched_result = Image.fromarray(np_image, mode="RGB")

    # Blur the mask out (into init image) by specified amount
    if mask_blur_radius > 0:
        nmd = cv2.erode(
            init_mask_pixels,
            kernel=np.ones((3, 3), dtype=np.uint8),
            iterations=int(mask_blur_radius / 2),
        )
        pmd = Image.fromarray(nmd, mode="L")
        blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
    else:
        blurred_init_mask = pil_init_mask

    multiplied_blurred_init_mask = ImageChops.multiply(
        blurred_init_mask, self.pil_image.split()[-1]
    )

    # Paste original on color-corrected generation (using blurred mask)
    matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
    return matched_result
@staticmethod
def sample_to_lowres_estimated_image(samples):
    """
    Build an approximate RGB image directly from latents by projecting the
    four latent channels of samples[0] onto RGB with a fixed matrix; the
    result has one pixel per latent position.
    """
    # originally adapted from code by @erucipe and @keturn here:
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

    # these updated numbers for v1.5 are from @torridgristle
    rgb_weights = [
        #    R        G        B
        [0.3444, 0.1385, 0.0670],  # L1
        [0.1247, 0.4027, 0.1494],  # L2
        [-0.3192, 0.2513, 0.2103],  # L3
        [-0.1307, -0.1874, -0.7445],  # L4
    ]
    v1_5_latent_rgb_factors = torch.tensor(
        rgb_weights,
        dtype=samples.dtype,
        device=samples.device,
    )

    # move the channel axis last so it can be matrix-multiplied into RGB
    channels_last = samples[0].permute(1, 2, 0)
    latent_image = channels_last @ v1_5_latent_rgb_factors

    unit_range = ((latent_image + 1) / 2).clamp(0, 1)  # change scale from -1..1 to 0..1
    byte_range = unit_range.mul(0xFF).byte()  # to 0..255
    latents_ubyte = byte_range.cpu()

    return Image.fromarray(latents_ubyte.numpy())
def generate_initial_noise(self, seed, width, height):
    """
    Prepare the fixed starting noise used when variations are in effect.

    With no variation settings, returns (seed, None). Otherwise noise is
    generated from `seed` and blended (via slerp) with the noise of every
    (seed, weight) pair in self.with_variations. When variation_amount > 0
    the returned seed is a newly drawn random one.

    :return: tuple (seed, initial_noise)
    """
    wants_variation = self.variation_amount > 0
    if not (wants_variation or len(self.with_variations) > 0):
        return (seed, None)

    # use fixed initial noise plus random noise per iteration
    set_seed(seed)
    initial_noise = self.get_noise(width, height)
    for variation_seed, variation_weight in self.with_variations:
        seed = variation_seed
        set_seed(seed)
        blended_in = self.get_noise(width, height)
        initial_noise = self.slerp(variation_weight, initial_noise, blended_in)

    if wants_variation:
        # reset RNG to an actually random state, so we can get a random seed for variations
        random.seed()
        seed = random.randrange(0, np.iinfo(np.uint32).max)

    return (seed, initial_noise)
2023-02-28 05:37:13 +00:00
2023-03-03 06:02:00 +00:00
def get_perlin_noise(self, width, height):
    """
    Return perlin noise with one layer per diffusion image channel (at most
    four), cropped to `height` x `width`, on the model's device.
    """
    device = self.model.device
    # on mps the individual layers are moved to the cpu before stacking
    fixdevice = "cpu" if (device.type == "mps") else device

    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)

    # round up to the nearest block of 8
    temp_width = int((width + 7) / 8) * 8
    temp_height = int((height + 7) / 8) * 8

    layers = []
    for _ in range(input_channels):
        layer = rand_perlin_2d((temp_height, temp_width), (8, 8), device=device)
        layers.append(layer.to(fixdevice))
    noise = torch.stack(layers, dim=0).to(device)

    # crop the padding back off
    return noise[0:4, 0:height, 0:width]
def new_seed(self):
    """
    Draw a random seed in [0, max uint32), remember it in self.seed and
    return it.
    """
    upper_bound = np.iinfo(np.uint32).max
    fresh_seed = random.randrange(0, upper_bound)
    self.seed = fresh_seed
    return fresh_seed
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
    """
    Spherical linear interpolation between two vectors.

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold above which the two vectors are
                               treated as collinear and plain linear
                               interpolation is used instead. Not
                               recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1. If either
                         input was not a numpy array (i.e. a torch tensor),
                         the result is returned as a torch tensor on the
                         model's device.
    """
    # The maths is done in numpy; remember whether anything had to be converted.
    inputs_are_torch = not (isinstance(v0, np.ndarray) and isinstance(v1, np.ndarray))
    if not isinstance(v0, np.ndarray):
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        v1 = v1.detach().cpu().numpy()

    # cosine of the angle between the two (flattened) vectors
    norm_product = np.linalg.norm(v0) * np.linalg.norm(v1)
    dot = np.sum(v0 * v1 / norm_product)

    if np.abs(dot) > DOT_THRESHOLD:
        # nearly parallel: fall back to linear interpolation
        interpolated = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        weight_start = np.sin(theta_0 - theta_t) / sin_theta_0
        weight_end = np.sin(theta_t) / sin_theta_0
        interpolated = weight_start * v0 + weight_end * v1

    if inputs_are_torch:
        interpolated = torch.from_numpy(interpolated).to(self.model.device)
    return interpolated
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):
    """
    Debugging helper: convert a generated sample into an image and write it
    to `filepath` as a PNG, creating the target directory if necessary.
    """
    image = self.sample_to_image(sample)

    # a bare file name has no directory part; treat it as the current directory
    dirname = os.path.dirname(filepath)
    if not dirname:
        dirname = "."

    directory_missing = not os.path.exists(dirname)
    if directory_missing:
        logger.info(f"creating directory {dirname}")
        os.makedirs(dirname, exist_ok=True)

    image.save(filepath, "PNG")
2023-02-28 05:37:13 +00:00
2023-03-03 06:02:00 +00:00
def torch_dtype(self) -> torch.dtype:
    """
    Map self.precision to a torch dtype: "float16" gives torch.float16,
    anything else gives torch.float32.
    """
    if self.precision == "float16":
        return torch.float16
    return torch.float32
2023-02-28 05:37:13 +00:00
# returns a tensor filled with random numbers from a normal distribution
2023-03-03 06:02:00 +00:00
def get_noise(self, width, height):
    """
    Return a tensor of normally distributed random numbers of shape
    [1, channels, height // downsampling_factor, width // downsampling_factor]
    on the model's device, optionally blended with perlin noise.
    """
    device = self.model.device
    factor = self.downsampling_factor

    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(self.latent_channels, 4)
    latent_shape = [1, input_channels, height // factor, width // factor]

    # For mps (or when use_mps_noise is set) the numbers are drawn on the
    # cpu and then moved; otherwise they are drawn directly on the device.
    draw_on_cpu = self.use_mps_noise or device.type == "mps"
    x = torch.randn(
        latent_shape,
        dtype=self.torch_dtype(),
        device="cpu" if draw_on_cpu else device,
    )
    if draw_on_cpu:
        x = x.to(device)

    if self.perlin > 0.0:
        perlin_noise = self.get_perlin_noise(width // factor, height // factor)
        x = (1 - self.perlin) * x + self.perlin * perlin_noise
    return x