mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'switch-ksampler-noise-scheduler-adaptively' into development
- This sets a step switchover point at which the k-samplers stop using the Karras noise schedule and start using the LatentDiffusion noise schedule. The advantage of this is that the Karras schedule produces excellent results at low step counts but starts to become unstable at high steps. - A new command argument --karras_max, lets the user set where the switchover occurs. Default is 29 steps (1-29 steps Karras), (30 or greater LDM) - Tildebyte, sorry to do a fast forward three-way merge for this but rebasing was just too painful due to extensive recent changes to the diffuser code.
This commit is contained in:
@ -180,6 +180,7 @@ class Generate:
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
self.karras_max = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -270,6 +271,7 @@ class Generate:
|
||||
variation_amount = 0.0,
|
||||
threshold = 0.0,
|
||||
perlin = 0.0,
|
||||
karras_max = None,
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
@ -353,7 +355,8 @@ class Generate:
|
||||
strength = strength or self.strength
|
||||
self.seed = seed
|
||||
self.log_tokenization = log_tokenization
|
||||
self.step_callback = step_callback
|
||||
self.step_callback = step_callback
|
||||
self.karras_max = karras_max
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
@ -398,6 +401,11 @@ class Generate:
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
if karras_max is not None and isinstance(self.sampler,KSampler):
|
||||
self.sampler.adjust_settings(karras_max=karras_max)
|
||||
|
||||
tic = time.time()
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@ -878,26 +886,23 @@ class Generate:
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
karras_max = self.karras_max # set in generate() call
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max)
|
||||
else:
|
||||
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
|
Reference in New Issue
Block a user