fix img2img variations/MPS (#353)

* fix img2img variations

* fix assert for variation_amount
This commit is contained in:
Kevin Gibbons 2022-09-03 23:34:20 -07:00 committed by GitHub
parent c22c3dec56
commit 751283a2de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,7 +286,7 @@ class T2I:
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
if len(with_variations) > 0 or variation_amount > 1.0:
if len(with_variations) > 0 or variation_amount > 0.0:
assert seed is not None,\
'seed must be specified when using with_variations'
if variation_amount == 0.0:
@ -336,6 +336,7 @@ class T2I:
callback=step_callback,
)
else:
init_latent = None
make_image = self._txt2img(
prompt,
steps=steps,
@ -351,11 +352,11 @@ class T2I:
if variation_amount > 0 or len(with_variations) > 0:
# use fixed initial noise plus random noise per iteration
seed_everything(seed)
initial_noise = self._get_noise(init_img,width,height)
initial_noise = self._get_noise(init_latent,width,height)
for v_seed, v_weight in with_variations:
seed = v_seed
seed_everything(seed)
next_noise = self._get_noise(init_img,width,height)
next_noise = self._get_noise(init_latent,width,height)
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
if variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
@ -367,7 +368,7 @@ class T2I:
x_T = None
if variation_amount > 0:
seed_everything(seed)
target_noise = self._get_noise(init_img,width,height)
target_noise = self._get_noise(init_latent,width,height)
x_T = self.slerp(variation_amount, initial_noise, target_noise)
elif initial_noise is not None:
# i.e. we specified particular variations
@ -375,7 +376,7 @@ class T2I:
else:
seed_everything(seed)
if self.device.type == 'mps':
x_T = self._get_noise(init_img,width,height)
x_T = self._get_noise(init_latent,width,height)
# make_image will do the equivalent of get_noise itself
print(f' DEBUG: seed at make_image() invocation time ={seed}')
image = make_image(x_T)
@ -606,8 +607,8 @@ class T2I:
return self.model
# returns a tensor filled with random numbers from a normal distribution
def _get_noise(self,init_img,width,height):
if init_img:
def _get_noise(self,init_latent,width,height):
if init_latent is not None:
if self.device.type == 'mps':
return torch.randn_like(init_latent, device='cpu').to(self.device)
else: