Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit 8c8b34a889 (parent 7ff94383ce)

* Update to resolve conflicts.
@@ -10,6 +10,7 @@ from PIL import Image
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.dream.devices import choose_autocast_device
+from ldm.util import rand_perlin_2d

 downsampling = 8

@@ -36,7 +37,7 @@ class Generator():
         self.with_variations = with_variations

     def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
-                 image_callback=None, step_callback=None,
+                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  **kwargs):
         device_type,scope = choose_autocast_device(self.model.device)
         make_image = self.get_make_image(
@@ -45,6 +46,8 @@ class Generator():
             width          = width,
             height         = height,
             step_callback  = step_callback,
+            threshold      = threshold,
+            perlin         = perlin,
             **kwargs
         )

@@ -63,10 +66,8 @@ class Generator():
                 x_T = initial_noise
             else:
                 seed_everything(seed)
-                if self.model.device.type == 'mps':
-                    x_T = self.get_noise(width,height)
+                x_T = self.get_noise(width,height)

-            # make_image will do the equivalent of get_noise itself
             image = make_image(x_T)
             results.append([image, seed])
             if image_callback is not None:
@@ -115,6 +116,10 @@ class Generator():
         """
         raise NotImplementedError("get_noise() must be implemented in a descendent class")

+    def get_perlin_noise(self,width,height):
+        return torch.stack([rand_perlin_2d((height, width), (8, 8)).to(self.model.device) for _ in range(self.latent_channels)], dim=0)
+
+
     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed
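Note on the Generator base-class hunks above: get_perlin_noise() builds one 2-D Perlin plane per latent channel, at latent resolution, so it can later be mixed into the usual Gaussian starting noise. Below is a minimal sketch of that shape handling, assuming 4 latent channels and substituting a smoothed-random stand-in for ldm.util.rand_perlin_2d; the names rand_perlin_2d_stub and perlin_latent_noise are ours, not from the repo.

import torch
import torch.nn.functional as F

def rand_perlin_2d_stub(shape, res):
    # Stand-in for ldm.util.rand_perlin_2d, used only so this sketch runs
    # outside the repo: coarse random values upsampled to (height, width),
    # rescaled to roughly [-1, 1].
    coarse = torch.rand(1, 1, res[0] + 1, res[1] + 1)
    smooth = F.interpolate(coarse, size=shape, mode='bilinear', align_corners=True)
    return smooth[0, 0] * 2 - 1

def perlin_latent_noise(width, height, latent_channels=4):
    # Mirrors Generator.get_perlin_noise(): one (height, width) noise plane
    # per latent channel, stacked into a (C, H, W) tensor.
    return torch.stack(
        [rand_perlin_2d_stub((height, width), (8, 8)) for _ in range(latent_channels)],
        dim=0,
    )

print(perlin_latent_noise(64, 64).shape)   # torch.Size([4, 64, 64])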
@@ -15,12 +15,12 @@ class Img2Img(Generator):

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,init_image,strength,step_callback=None,**kwargs):
+                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
         """
+        self.perlin = perlin
         # PLMS sampler not supported yet, so ignore previous sampler
         if not isinstance(sampler,DDIMSampler):
             print(
@@ -67,6 +67,10 @@ class Img2Img(Generator):
         init_latent = self.init_latent
         assert init_latent is not None,'call to get_noise() when init_latent not set'
         if device.type == 'mps':
-            return torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(init_latent, device='cpu').to(device)
         else:
-            return torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(init_latent, device=device)
+        if self.perlin > 0.0:
+            shape = init_latent.shape
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
+        return x
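Note on Img2Img.get_noise() above: instead of returning the Gaussian tensor directly, it now keeps it in x and, when perlin > 0, linearly interpolates it with Perlin noise sized from the init latent (shape[3] is the latent width and shape[2] the latent height of the NCHW tensor). A small sketch of that blend with hypothetical shapes; the Perlin plane broadcasts over the batch dimension.

import torch

def blend_with_perlin(x, perlin_noise, perlin):
    # Linear interpolation used by the new get_noise() tails:
    # perlin = 0.0 keeps pure Gaussian noise, perlin = 1.0 keeps pure Perlin noise.
    return (1 - perlin) * x + perlin * perlin_noise

init_latent = torch.zeros(1, 4, 64, 64)            # hypothetical NCHW init latent
x = torch.randn_like(init_latent)                  # Gaussian noise, same shape
p = torch.randn(4, 64, 64)                         # stand-in for get_perlin_noise(64, 64)
print(blend_with_perlin(x, p, perlin=0.25).shape)  # torch.Size([1, 4, 64, 64])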
@@ -12,12 +12,13 @@ class Txt2Img(Generator):

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,**kwargs):
+                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
         kwargs are 'width' and 'height'
         """
+        self.perlin = perlin
         uc, c = conditioning

         @torch.no_grad()
@@ -37,7 +38,8 @@ class Txt2Img(Generator):
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
                 eta = ddim_eta,
-                img_callback = step_callback
+                img_callback = step_callback,
+                threshold = threshold,
             )
             return self.sample_to_image(samples)

@@ -48,14 +50,18 @@ class Txt2Img(Generator):
     def get_noise(self,width,height):
         device = self.model.device
         if device.type == 'mps':
-            return torch.randn([1,
+            x = torch.randn([1,
                                 self.latent_channels,
                                 height // self.downsampling_factor,
                                 width // self.downsampling_factor],
                                device='cpu').to(device)
         else:
-            return torch.randn([1,
+            x = torch.randn([1,
                                 self.latent_channels,
                                 height // self.downsampling_factor,
                                 width // self.downsampling_factor],
                                device=device)
+        print(self.perlin)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
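Note on Txt2Img.get_noise() above: the Gaussian tensor is built at latent resolution (image size divided by the module's downsampling factor of 8) and then blended exactly as in Img2Img. A self-contained sketch of the sizing, with 4 latent channels assumed and plain Gaussian noise standing in for get_perlin_noise; the function name txt2img_initial_noise is ours.

import torch

def txt2img_initial_noise(width, height, perlin=0.0,
                          latent_channels=4, downsampling_factor=8):
    # Mirrors the reworked Txt2Img.get_noise(): Gaussian noise at latent
    # resolution, optionally mixed with Perlin noise before sampling starts.
    lh, lw = height // downsampling_factor, width // downsampling_factor
    x = torch.randn(1, latent_channels, lh, lw)
    if perlin > 0.0:
        p = torch.randn(latent_channels, lh, lw)   # stand-in for get_perlin_noise(lw, lh)
        x = (1 - perlin) * x + perlin * p
    return x

print(txt2img_initial_noise(512, 512, perlin=0.5).shape)  # torch.Size([1, 4, 64, 64])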
@@ -193,6 +193,8 @@ class Generate:
            log_tokenization = False,
            with_variations  = None,
            variation_amount = 0.0,
+           threshold        = 0.0,
+           perlin           = 0.0,
            # these are specific to img2img
            init_img         = None,
            mask             = None,
@@ -335,7 +337,9 @@ class Generate:
                height          = height,
                init_image      = init_image,       # notice that init_image is different from init_img
                init_mask       = init_mask_image,
-               strength        = strength
+               strength        = strength,
+               threshold       = threshold,
+               perlin          = perlin,
            )

            if upscale is not None or gfpgan_strength > 0:
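Note on the Generate hunks above: the two new keyword arguments default to 0.0 and are simply forwarded into the generator call, so existing callers are unaffected. Below is a hypothetical call sketch; the enclosing method is not named in these hunks, so prompt2image (and the ldm.generate import path) is an assumption rather than something the diff confirms.

from ldm.generate import Generate   # import path assumed, not shown in the diff

gr = Generate()
results = gr.prompt2image(
    prompt    = 'a sunlit forest clearing',
    width     = 512,
    height    = 512,
    perlin    = 0.3,   # 0.0 = pure Gaussian starting noise, 1.0 = pure Perlin noise
    threshold = 0.0,   # forwarded through the generator kwargs to the sampler
)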