add threshold for switchover from Karras to LDM noise schedule

This commit is contained in:
Lincoln Stein
2022-10-27 15:50:32 -04:00
parent 3e48b9ff85
commit 943808b925
5 changed files with 46 additions and 17 deletions

View File

@ -176,6 +176,7 @@ class Generate:
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.karras_max = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -253,6 +254,7 @@ class Generate:
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
karras_max = None,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
@ -331,7 +333,8 @@ class Generate:
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.step_callback = step_callback
self.karras_max = karras_max
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@ -376,6 +379,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
if karras_max is not None and isinstance(self.sampler,KSampler):
self.sampler.adjust_settings(karras_max=karras_max)
tic = time.time()
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@ -815,26 +823,23 @@ class Generate:
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
karras_max = self.karras_max # set in generate() call
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max)
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)