txt2img: support switching diffusers schedulers

Kevin Turner 2022-11-10 14:36:45 -08:00
parent ae9b482acf
commit d55e22981a
5 changed files with 50 additions and 29 deletions


@@ -12,6 +12,9 @@ SAMPLER_CHOICES = [
     "k_heun",
     "k_lms",
     "plms",
+    # diffusers:
+    "ipndm",
+    "pndm",
 ]
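For context: SAMPLER_CHOICES is presumably the list the CLI validates its sampler option against, so adding "ipndm" and "pndm" here is what exposes the new diffusers schedulers to users. A minimal sketch of that kind of wiring, with the flag name and default assumed for illustration rather than taken from this diff:

import argparse

SAMPLER_CHOICES = [
    "ddim", "k_dpm_2_a", "k_dpm_2", "k_euler_a", "k_euler",
    "k_heun", "k_lms", "plms",
    # diffusers:
    "ipndm", "pndm",
]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sampler",                  # flag name assumed for illustration
    dest="sampler_name",
    choices=SAMPLER_CHOICES,      # values outside the list are rejected at parse time
    default="k_lms",              # assumed default
)
args = parser.parse_args(["--sampler", "pndm"])
print(args.sampler_name)          # -> pndm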


@@ -19,7 +19,7 @@ import hashlib
 import cv2
 import skimage
 from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
-    EulerAncestralDiscreteScheduler
+    EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler
 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
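PNDMScheduler and IPNDMScheduler are the diffusers classes backing the new "pndm" and "ipndm" choices. A hedged sketch of constructing them the same way this commit does, via from_config with the scheduler subfolder; the repo id is only an example, not something this diff specifies:

from diffusers import PNDMScheduler, IPNDMScheduler

repo = "CompVis/stable-diffusion-v1-4"   # illustrative; any diffusers-format model directory or repo id
pndm = PNDMScheduler.from_config(repo, subfolder="scheduler")
ipndm = IPNDMScheduler.from_config(repo, subfolder="scheduler")
print(type(pndm).__name__, type(ipndm).__name__)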
@@ -1004,36 +1004,39 @@ class Generate:
         print(msg)
 
     def _set_scheduler(self):
-        msg = f'>> Setting Sampler to {self.sampler_name}'
         default = self.model.scheduler
-        # TODO: Test me! Not all schedulers take the same args.
-        scheduler_args = dict(
-            num_train_timesteps=default.num_train_timesteps,
-            beta_start=default.beta_start,
-            beta_end=default.beta_end,
-            beta_schedule=default.beta_schedule,
-        )
-        trained_betas = getattr(self.model.scheduler, 'trained_betas')
-        if trained_betas is not None:
-            scheduler_args.update(trained_betas=trained_betas)
 
-        if self.sampler_name == 'plms':
-            raise NotImplementedError("What's the diffusers implementation of PLMS?")
-        elif self.sampler_name == 'ddim':
-            self.sampler = DDIMScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_dpm_2_a':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_dpm_2':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_euler_a':
-            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_euler':
-            self.sampler = EulerDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_heun':
-            raise NotImplementedError("no diffusers implementation of Heun's sampler")
-        elif self.sampler_name == 'k_lms':
-            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+        higher_order_samplers = [
+            'k_dpm_2',
+            'k_dpm_2_a',
+            'k_heun',
+            'plms',  # Its first step is like Heun
+        ]
+        scheduler_map = dict(
+            ddim=DDIMScheduler,
+            ipndm=IPNDMScheduler,
+            k_euler=EulerDiscreteScheduler,
+            k_euler_a=EulerAncestralDiscreteScheduler,
+            k_lms=LMSDiscreteScheduler,
+            pndm=PNDMScheduler,
+        )
+
+        if self.sampler_name in scheduler_map:
+            sampler_class = scheduler_map[self.sampler_name]
+            msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
+            self.sampler = sampler_class.from_config(
+                self.model_cache.model_name_or_path(self.model_name),
+                subfolder="scheduler"
+            )
+        elif self.sampler_name in higher_order_samplers:
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'— diffusers does not yet support higher-order samplers, '
+                   f'Defaulting to {default}')
+            self.sampler = default
         else:
-            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'Defaulting to {default}')
+            self.sampler = default
 
         print(msg)
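The rewritten _set_scheduler replaces the if/elif chain and the hand-copied beta_* arguments with a table lookup plus from_config, which reads the scheduler configuration stored alongside the model itself. A small sketch of what that buys, with an illustrative repo id:

from diffusers import LMSDiscreteScheduler

# Loading from the model's "scheduler" subfolder picks up num_train_timesteps,
# beta_start, beta_end, beta_schedule, etc. from the model's own config, which
# is what the removed scheduler_args plumbing used to copy over by hand.
sched = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
print(sched.config.beta_schedule, sched.config.num_train_timesteps)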


@@ -108,6 +108,9 @@ SAMPLER_CHOICES = [
     'k_heun',
     'k_lms',
     'plms',
+    # diffusers:
+    "ipndm",
+    "pndm",
 ]
 PRECISION_CHOICES = [


@@ -24,7 +24,7 @@ class Txt2Img(Generator):
         uc, c, extra_conditioning_info = conditioning
         pipeline = self.model
-        # TODO: customize a new pipeline for the given sampler (Scheduler)
+        pipeline.scheduler = sampler
 
         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality
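Assigning pipeline.scheduler is the standard diffusers pattern for swapping samplers: the pipeline uses whatever scheduler it holds at call time. A hedged sketch of the same pattern outside this codebase; the repo id and prompt are illustrative:

from diffusers import StableDiffusionPipeline, PNDMScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4",
                                           subfolder="scheduler")
image = pipe("a lighthouse at dusk", num_inference_steps=30).images[0]
image.save("lighthouse.png")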


@@ -318,6 +318,18 @@ class ModelCache(object):
         return pipeline, width, height, model_hash
 
+    def model_name_or_path(self, model_name:str) -> str | Path:
+        if model_name not in self.config:
+            raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
+
+        mconfig = self.config[model_name]
+        if 'repo_name' in mconfig:
+            return mconfig['repo_name']
+        elif 'path' in mconfig:
+            return Path(mconfig['path'])
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
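For reference, a toy illustration of the two config shapes model_name_or_path accepts and what it returns for each; the model names and values below are invented, not taken from a real models.yaml:

from pathlib import Path
from omegaconf import OmegaConf

config = OmegaConf.create({
    'hub-model':   {'repo_name': 'some-org/some-diffusers-model'},    # invented entry
    'local-model': {'path': '/data/models/local-diffusers-model'},    # invented entry
})

for name in ('hub-model', 'local-model'):
    mconfig = config[name]
    if 'repo_name' in mconfig:
        print(name, '->', mconfig['repo_name'])    # hub repo id, returned as a str
    elif 'path' in mconfig:
        print(name, '->', Path(mconfig['path']))   # local directory, returned as a Path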