add threshold for switchover from Karras to LDM noise schedule

This commit is contained in:
Lincoln Stein
2022-10-27 15:50:32 -04:00
parent 343ae8b7af
commit 1200fbd3bd
5 changed files with 45 additions and 18 deletions

View File

@ -180,6 +180,7 @@ class Generate:
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
self.karras_max = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -270,6 +271,7 @@ class Generate:
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
karras_max = None,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
@ -353,7 +355,8 @@ class Generate:
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.step_callback = step_callback
self.karras_max = karras_max
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@ -398,6 +401,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
if karras_max is not None and isinstance(self.sampler,KSampler):
self.sampler.adjust_settings(karras_max=karras_max)
tic = time.time()
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@ -878,26 +886,23 @@ class Generate:
# consistent, at least
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
karras_max = self.karras_max # set in generate() call
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max)
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)