code cleanup

* check that a fixed seed is provided when requesting a variant parameter sweep (-v)
* move _get_noise() into the outer scope to improve readability - a refactoring of the big method call is still needed
Lincoln Stein 2022-09-03 10:40:20 -04:00
parent fe5cc79249
commit 5454a0edc2
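
For context on the first bullet: the variation machinery derives each variant by spherically interpolating (slerp) from one reproducible base noise tensor toward per-variation noise, so a sweep over -v values is only meaningful when the base seed is pinned. A minimal standalone sketch of that pattern follows; the slerp helper, tensor shape, and seed values are illustrative assumptions, not the T2I implementation itself.

import torch

def slerp(t, a, b, eps=1e-8):
    # spherical interpolation between two noise tensors (illustrative helper)
    a_n = a / (a.norm() + eps)
    b_n = b / (b.norm() + eps)
    omega = torch.arccos((a_n * b_n).sum().clamp(-1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < eps:  # nearly parallel: fall back to plain lerp
        return (1.0 - t) * a + t * b
    return (torch.sin((1.0 - t) * omega) / so) * a + (torch.sin(t * omega) / so) * b

def make_variant_noise(base_seed, variations, shape=(1, 4, 64, 64)):
    # reproducible base noise: only possible because base_seed is fixed
    torch.manual_seed(base_seed)
    noise = torch.randn(shape)
    for v_seed, v_weight in variations:
        torch.manual_seed(v_seed)
        noise = slerp(v_weight, noise, torch.randn(shape))
    return noise

# the same fixed seed and variation list yield identical starting noise on every run
x_T = make_variant_noise(42, [(1001, 0.2), (1002, 0.1)])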

@@ -286,7 +286,7 @@ class T2I:
             0.0 <= variation_amount <= 1.0
         ), '-v --variation_amount must be in [0.0, 1.0]'
-        if len(with_variations) > 0:
+        if len(with_variations) > 0 or variation_amount > 1.0:
             assert seed is not None,\
                 'seed must be specified when using with_variations'
         if variation_amount == 0.0:
@@ -324,6 +324,7 @@ class T2I:
                 self.model.encode_first_stage(init_image)
             )  # move to latent space
+            print(f' DEBUG: seed at make_image time ={seed}')
             make_image = self._img2img(
                 prompt,
                 steps=steps,
@@ -346,35 +347,15 @@ class T2I:
                 callback=step_callback,
             )
-        def get_noise():
-            if init_img:
-                if self.device.type == 'mps':
-                    return torch.randn_like(init_latent, device='cpu').to(self.device)
-                else:
-                    return torch.randn_like(init_latent, device=self.device)
-            else:
-                if self.device.type == 'mps':
-                    return torch.randn([1,
-                                        self.latent_channels,
-                                        height // self.downsampling_factor,
-                                        width // self.downsampling_factor],
-                                       device='cpu').to(self.device)
-                else:
-                    return torch.randn([1,
-                                        self.latent_channels,
-                                        height // self.downsampling_factor,
-                                        width // self.downsampling_factor],
-                                       device=self.device)
         initial_noise = None
         if variation_amount > 0 or len(with_variations) > 0:
             # use fixed initial noise plus random noise per iteration
             seed_everything(seed)
-            initial_noise = get_noise()
+            initial_noise = self._get_noise(init_img,width,height)
             for v_seed, v_weight in with_variations:
                 seed = v_seed
                 seed_everything(seed)
-                next_noise = get_noise()
+                next_noise = self._get_noise(init_img,width,height)
                 initial_noise = self.slerp(v_weight, initial_noise, next_noise)
             if variation_amount > 0:
                 random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
@@ -386,7 +367,7 @@ class T2I:
                 x_T = None
                 if variation_amount > 0:
                     seed_everything(seed)
-                    target_noise = get_noise()
+                    target_noise = self._get_noise(init_img,width,height)
                     x_T = self.slerp(variation_amount, initial_noise, target_noise)
                 elif initial_noise is not None:
                     # i.e. we specified particular variations
@@ -394,8 +375,9 @@ class T2I:
                 else:
                     seed_everything(seed)
                     if self.device.type == 'mps':
-                        x_T = get_noise()
+                        x_T = self._get_noise(init_img,width,height)
                     # make_image will do the equivalent of get_noise itself
+                print(f' DEBUG: seed at make_image() invocation time ={seed}')
                 image = make_image(x_T)
                 results.append([image, seed])
                 if image_callback is not None:
@@ -623,6 +605,27 @@ class T2I:
         return self.model

+    # returns a tensor filled with random numbers from a normal distribution
+    def _get_noise(self,init_img,width,height):
+        if init_img:
+            if self.device.type == 'mps':
+                return torch.randn_like(init_latent, device='cpu').to(self.device)
+            else:
+                return torch.randn_like(init_latent, device=self.device)
+        else:
+            if self.device.type == 'mps':
+                return torch.randn([1,
+                                    self.latent_channels,
+                                    height // self.downsampling_factor,
+                                    width // self.downsampling_factor],
+                                   device='cpu').to(self.device)
+            else:
+                return torch.randn([1,
+                                    self.latent_channels,
+                                    height // self.downsampling_factor,
+                                    width // self.downsampling_factor],
+                                   device=self.device)
+
     def _set_sampler(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
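
On the relocated helper: the mps branches of _get_noise() draw the noise on the CPU and then move it to the device, presumably because seeded generation directly on Apple's MPS backend was not reliably reproducible at the time. A rough standalone sketch of that workaround, with the shape and device selection as illustrative assumptions rather than the T2I code:

import torch

def seeded_noise(shape, device, seed):
    # draw on the CPU so a fixed seed gives the same tensor regardless of backend,
    # then move the result to the target device (mirrors the mps special case above)
    torch.manual_seed(seed)
    if device.type == 'mps':
        return torch.randn(shape, device='cpu').to(device)
    return torch.randn(shape, device=device)

device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
x_T = seeded_noise((1, 4, 64, 64), device, seed=42)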