Merge branch 'switch-ksampler-noise-scheduler-adaptively' into development

- This sets a step switchover point at which the k-samplers stop using the
  Karras noise schedule and start using the LatentDiffusion noise schedule.
  The advantage of this is that the Karras schedule produces excellent
  results at low step counts but starts to become unstable at high
  steps.

- A new command argument --karras_max, lets the user set where the
  switchover occurs. Default is 29 steps (1-29 steps Karras),
  (30 or greater LDM)

- Tildebyte, sorry to do a fast forward three-way merge for this
  but rebasing was just too painful due to extensive recent
  changes to the diffuser code.
This commit is contained in:
Lincoln Stein
2022-10-27 16:11:26 -04:00
5 changed files with 52 additions and 17 deletions

View File

@ -180,6 +180,7 @@ class Generate:
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
self.karras_max = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -270,6 +271,7 @@ class Generate:
variation_amount = 0.0,
threshold = 0.0,
perlin = 0.0,
karras_max = None,
# these are specific to img2img and inpaint
init_img = None,
init_mask = None,
@ -353,7 +355,8 @@ class Generate:
strength = strength or self.strength
self.seed = seed
self.log_tokenization = log_tokenization
self.step_callback = step_callback
self.step_callback = step_callback
self.karras_max = karras_max
with_variations = [] if with_variations is None else with_variations
# will instantiate the model or return it from cache
@ -398,6 +401,11 @@ class Generate:
self.sampler_name = sampler_name
self._set_sampler()
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for
if karras_max is not None and isinstance(self.sampler,KSampler):
self.sampler.adjust_settings(karras_max=karras_max)
tic = time.time()
if self._has_cuda():
torch.cuda.reset_peak_memory_stats()
@ -878,26 +886,23 @@ class Generate:
# consistent, at least
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
karras_max = self.karras_max # set in generate() call
if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(
self.model, 'dpm_2_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(
self.model, 'euler_ancestral', device=self.device
)
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device)
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device)
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device)
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max)
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)