mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add threshold for switchover from Karras to LDM noise schedule
This commit is contained in:
@ -180,6 +180,7 @@ class Generate:
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
self.karras_max = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -270,6 +271,7 @@ class Generate:
|
||||
variation_amount = 0.0,
|
||||
threshold = 0.0,
|
||||
perlin = 0.0,
|
||||
karras_max = None,
|
||||
# these are specific to img2img and inpaint
|
||||
init_img = None,
|
||||
init_mask = None,
|
||||
@ -353,7 +355,8 @@ class Generate:
|
||||
strength = strength or self.strength
|
||||
self.seed = seed
|
||||
self.log_tokenization = log_tokenization
|
||||
self.step_callback = step_callback
|
||||
self.step_callback = step_callback
|
||||
self.karras_max = karras_max
|
||||
with_variations = [] if with_variations is None else with_variations
|
||||
|
||||
# will instantiate the model or return it from cache
|
||||
@ -398,6 +401,11 @@ class Generate:
|
||||
self.sampler_name = sampler_name
|
||||
self._set_sampler()
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
if karras_max is not None and isinstance(self.sampler,KSampler):
|
||||
self.sampler.adjust_settings(karras_max=karras_max)
|
||||
|
||||
tic = time.time()
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@ -878,26 +886,23 @@ class Generate:
|
||||
# consistent, at least
|
||||
def _set_sampler(self):
|
||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||
karras_max = self.karras_max # set in generate() call
|
||||
if self.sampler_name == 'plms':
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model, device=self.device)
|
||||
elif self.sampler_name == 'k_dpm_2_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'dpm_2_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_dpm_2':
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_euler_a':
|
||||
self.sampler = KSampler(
|
||||
self.model, 'euler_ancestral', device=self.device
|
||||
)
|
||||
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_euler':
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_heun':
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max)
|
||||
elif self.sampler_name == 'k_lms':
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device)
|
||||
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max)
|
||||
else:
|
||||
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
|
||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||
|
Reference in New Issue
Block a user