resolved whitespace difference

This commit is contained in:
Lincoln Stein 2022-10-27 17:12:22 -04:00
commit 19a6e904ec
2 changed files with 6 additions and 8 deletions

View File

@ -886,23 +886,22 @@ class Generate:
# consistent, at least # consistent, at least
def _set_sampler(self): def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}' msg = f'>> Setting Sampler to {self.sampler_name}'
karras_max = self.karras_max # set in generate() call
if self.sampler_name == 'plms': if self.sampler_name == 'plms':
self.sampler = PLMSSampler(self.model, device=self.device) self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim': elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device) self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a': elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
elif self.sampler_name == 'k_dpm_2': elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model, 'dpm_2', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a': elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
elif self.sampler_name == 'k_euler': elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model, 'euler', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun': elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model, 'heun', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms': elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model, 'lms', device=self.device, karras_max=karras_max) self.sampler = KSampler(self.model, 'lms', device=self.device)
else: else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device) self.sampler = PLMSSampler(self.model, device=self.device)

View File

@ -3,7 +3,6 @@ ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
''' '''
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm