txt2img: support switching diffusers schedulers

This commit is contained in:
Kevin Turner 2022-11-10 14:36:45 -08:00
parent ae9b482acf
commit d55e22981a
5 changed files with 50 additions and 29 deletions

View File

@ -12,6 +12,9 @@ SAMPLER_CHOICES = [
"k_heun",
"k_lms",
"plms",
# diffusers:
"ipndm",
"pndm",
]

View File

@ -19,7 +19,7 @@ import hashlib
import cv2
import skimage
from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
EulerAncestralDiscreteScheduler
EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler
from omegaconf import OmegaConf
from ldm.invoke.generator.base import downsampling
@ -1004,36 +1004,39 @@ class Generate:
print(msg)
def _set_scheduler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
default = self.model.scheduler
# TODO: Test me! Not all schedulers take the same args.
scheduler_args = dict(
num_train_timesteps=default.num_train_timesteps,
beta_start=default.beta_start,
beta_end=default.beta_end,
beta_schedule=default.beta_schedule,
higher_order_samplers = [
'k_dpm_2',
'k_dpm_2_a',
'k_heun',
'plms', # Its first step is like Heun
]
scheduler_map = dict(
ddim=DDIMScheduler,
ipndm=IPNDMScheduler,
k_euler=EulerDiscreteScheduler,
k_euler_a=EulerAncestralDiscreteScheduler,
k_lms=LMSDiscreteScheduler,
pndm=PNDMScheduler,
)
trained_betas = getattr(self.model.scheduler, 'trained_betas')
if trained_betas is not None:
scheduler_args.update(trained_betas=trained_betas)
if self.sampler_name == 'plms':
raise NotImplementedError("What's the diffusers implementation of PLMS?")
elif self.sampler_name == 'ddim':
self.sampler = DDIMScheduler(**scheduler_args)
elif self.sampler_name == 'k_dpm_2_a':
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
elif self.sampler_name == 'k_dpm_2':
raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
elif self.sampler_name == 'k_euler_a':
self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
elif self.sampler_name == 'k_euler':
self.sampler = EulerDiscreteScheduler(**scheduler_args)
elif self.sampler_name == 'k_heun':
raise NotImplementedError("no diffusers implementation of Heun's sampler")
elif self.sampler_name == 'k_lms':
self.sampler = LMSDiscreteScheduler(**scheduler_args)
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
self.sampler = sampler_class.from_config(
self.model_cache.model_name_or_path(self.model_name),
subfolder="scheduler"
)
elif self.sampler_name in higher_order_samplers:
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'— diffusers does not yet support higher-order samplers, '
f'Defaulting to {default}')
self.sampler = default
else:
msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
f'Defaulting to {default}')
self.sampler = default
print(msg)

View File

@ -108,6 +108,9 @@ SAMPLER_CHOICES = [
'k_heun',
'k_lms',
'plms',
# diffusers:
"ipndm",
"pndm",
]
PRECISION_CHOICES = [

View File

@ -24,7 +24,7 @@ class Txt2Img(Generator):
uc, c, extra_conditioning_info = conditioning
pipeline = self.model
# TODO: customize a new pipeline for the given sampler (Scheduler)
pipeline.scheduler = sampler
def make_image(x_T) -> PIL.Image.Image:
# FIXME: restore free_gpu_mem functionality

View File

@ -318,6 +318,18 @@ class ModelCache(object):
return pipeline, width, height, model_hash
def model_name_or_path(self, model_name:str) -> str | Path:
if model_name not in self.config:
raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
mconfig = self.config[model_name]
if 'repo_name' in mconfig:
return mconfig['repo_name']
elif 'path' in mconfig:
return Path(mconfig['path'])
else:
raise ValueError("Model config must specify either repo_name or path.")
def offload_model(self, model_name:str):
'''
Offload the indicated model to CPU. Will call