* Update to resolve conflicts.

This commit is contained in:
Peter Baylies 2022-09-05 22:57:33 -04:00
parent 7ff94383ce
commit 8c8b34a889
4 changed files with 32 additions and 13 deletions

View File

@ -10,6 +10,7 @@ from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device
from ldm.util import rand_perlin_2d
downsampling = 8
@ -36,7 +37,7 @@ class Generator():
self.with_variations = with_variations
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
**kwargs):
device_type,scope = choose_autocast_device(self.model.device)
make_image = self.get_make_image(
@ -45,6 +46,8 @@ class Generator():
width = width,
height = height,
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
**kwargs
)
@ -63,10 +66,8 @@ class Generator():
x_T = initial_noise
else:
seed_everything(seed)
if self.model.device.type == 'mps':
x_T = self.get_noise(width,height)
x_T = self.get_noise(width,height)
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
results.append([image, seed])
if image_callback is not None:
@ -115,6 +116,10 @@ class Generator():
"""
raise NotImplementedError("get_noise() must be implemented in a descendent class")
def get_perlin_noise(self,width,height):
return torch.stack([rand_perlin_2d((height, width), (8, 8)).to(self.model.device) for _ in range(self.latent_channels)], dim=0)
def new_seed(self):
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed

View File

@ -15,12 +15,12 @@ class Img2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,**kwargs):
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
# PLMS sampler not supported yet, so ignore previous sampler
if not isinstance(sampler,DDIMSampler):
print(
@ -67,6 +67,10 @@ class Img2Img(Generator):
init_latent = self.init_latent
assert init_latent is not None,'call to get_noise() when init_latent not set'
if device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(device)
x = torch.randn_like(init_latent, device='cpu').to(device)
else:
return torch.randn_like(init_latent, device=device)
x = torch.randn_like(init_latent, device=device)
if self.perlin > 0.0:
shape = init_latent.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x

View File

@ -12,12 +12,13 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,**kwargs):
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
uc, c = conditioning
@torch.no_grad()
@ -37,7 +38,8 @@ class Txt2Img(Generator):
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback
img_callback = step_callback,
threshold = threshold,
)
return self.sample_to_image(samples)
@ -48,14 +50,18 @@ class Txt2Img(Generator):
def get_noise(self,width,height):
device = self.model.device
if device.type == 'mps':
return torch.randn([1,
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device='cpu').to(device)
else:
return torch.randn([1,
x = torch.randn([1,
self.latent_channels,
height // self.downsampling_factor,
width // self.downsampling_factor],
device=device)
print(self.perlin)
if self.perlin > 0.0:
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
return x

View File

@ -193,6 +193,8 @@ class Generate:
log_tokenization= False,
with_variations = None,
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
# these are specific to img2img
init_img = None,
mask = None,
@ -335,7 +337,9 @@ class Generate:
height = height,
init_image = init_image, # notice that init_image is different from init_img
init_mask = init_mask_image,
strength = strength
strength = strength,
threshold = threshold,
perlin = perlin,
)
if upscale is not None or gfpgan_strength > 0: