mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
txt2img: support switching diffusers schedulers
This commit is contained in:
parent ae9b482acf
commit d55e22981a
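
What the change enables, in user terms: the sampler name chosen on the command line now maps to a diffusers Scheduler class, which is loaded from the model repo's "scheduler" subfolder and swapped onto the pipeline. A minimal standalone sketch of that pattern follows; the repo id is only an example, and on newer diffusers releases the equivalent loader is from_pretrained rather than from_config.

from diffusers import DiffusionPipeline, LMSDiscreteScheduler

# Load a pipeline, then replace its default scheduler with the LMS scheduler
# (InvokeAI's "k_lms"), pulling its config from the repo's "scheduler"
# subfolder - the same from_config(..., subfolder="scheduler") call the
# diff below uses.
pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
sampler = LMSDiscreteScheduler.from_config(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
pipeline.scheduler = sampler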
@@ -12,6 +12,9 @@ SAMPLER_CHOICES = [
     "k_heun",
     "k_lms",
     "plms",
+    # diffusers:
+    "ipndm",
+    "pndm",
 ]
@@ -19,7 +19,7 @@ import hashlib
 import cv2
 import skimage
 from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
-    EulerAncestralDiscreteScheduler
+    EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler

 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -1004,36 +1004,39 @@ class Generate:
         print(msg)

     def _set_scheduler(self):
-        msg = f'>> Setting Sampler to {self.sampler_name}'
         default = self.model.scheduler
-        # TODO: Test me! Not all schedulers take the same args.
-        scheduler_args = dict(
-            num_train_timesteps=default.num_train_timesteps,
-            beta_start=default.beta_start,
-            beta_end=default.beta_end,
-            beta_schedule=default.beta_schedule,
-        )
+        higher_order_samplers = [
+            'k_dpm_2',
+            'k_dpm_2_a',
+            'k_heun',
+            'plms',  # Its first step is like Heun
+        ]
+        scheduler_map = dict(
+            ddim=DDIMScheduler,
+            ipndm=IPNDMScheduler,
+            k_euler=EulerDiscreteScheduler,
+            k_euler_a=EulerAncestralDiscreteScheduler,
+            k_lms=LMSDiscreteScheduler,
+            pndm=PNDMScheduler,
+        )
-        trained_betas = getattr(self.model.scheduler, 'trained_betas')
-        if trained_betas is not None:
-            scheduler_args.update(trained_betas=trained_betas)
-        if self.sampler_name == 'plms':
-            raise NotImplementedError("What's the diffusers implementation of PLMS?")
-        elif self.sampler_name == 'ddim':
-            self.sampler = DDIMScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_dpm_2_a':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_dpm_2':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_euler_a':
-            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_euler':
-            self.sampler = EulerDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_heun':
-            raise NotImplementedError("no diffusers implementation of Heun's sampler")
-        elif self.sampler_name == 'k_lms':
-            self.sampler = LMSDiscreteScheduler(**scheduler_args)
-
+        if self.sampler_name in scheduler_map:
+            sampler_class = scheduler_map[self.sampler_name]
+            msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
+            self.sampler = sampler_class.from_config(
+                self.model_cache.model_name_or_path(self.model_name),
+                subfolder="scheduler"
+            )
+        elif self.sampler_name in higher_order_samplers:
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'— diffusers does not yet support higher-order samplers, '
+                   f'Defaulting to {default}')
+            self.sampler = default
         else:
-            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+            msg = (f'>> Unsupported Sampler: {self.sampler_name} '
+                   f'Defaulting to {default}')
             self.sampler = default

         print(msg)
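
The new _set_scheduler dispatches in three steps: look the sampler name up in scheduler_map and load that class via from_config, fall back to the pipeline's default scheduler with a warning for higher-order samplers that diffusers does not yet implement, and fall back to the default for anything else. Below is a simplified standalone sketch of that order; a stub returns the class instead of calling from_config, so it runs without downloading a model.

from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler,
                       EulerDiscreteScheduler, LMSDiscreteScheduler)

scheduler_map = dict(
    ddim=DDIMScheduler,
    k_euler=EulerDiscreteScheduler,
    k_euler_a=EulerAncestralDiscreteScheduler,
    k_lms=LMSDiscreteScheduler,
)
higher_order_samplers = ['k_dpm_2', 'k_dpm_2_a', 'k_heun', 'plms']

def pick_scheduler(sampler_name, default):
    # Known diffusers schedulers win; higher-order samplers and unknown
    # names both keep the pipeline's existing (default) scheduler.
    if sampler_name in scheduler_map:
        return scheduler_map[sampler_name]
    if sampler_name in higher_order_samplers:
        print(f'>> {sampler_name} has no diffusers implementation yet, keeping default')
    return default

print(pick_scheduler('k_lms', 'pipeline-default'))   # -> LMSDiscreteScheduler class
print(pick_scheduler('k_heun', 'pipeline-default'))  # -> warning, then 'pipeline-default'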
@@ -108,6 +108,9 @@ SAMPLER_CHOICES = [
     'k_heun',
     'k_lms',
     'plms',
+    # diffusers:
+    "ipndm",
+    "pndm",
 ]

 PRECISION_CHOICES = [
@@ -24,7 +24,7 @@ class Txt2Img(Generator):
         uc, c, extra_conditioning_info = conditioning

         pipeline = self.model
-        # TODO: customize a new pipeline for the given sampler (Scheduler)
+        pipeline.scheduler = sampler

         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality
@@ -318,6 +318,18 @@ class ModelCache(object):

         return pipeline, width, height, model_hash

+    def model_name_or_path(self, model_name:str) -> str | Path:
+        if model_name not in self.config:
+            raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
+
+        mconfig = self.config[model_name]
+        if 'repo_name' in mconfig:
+            return mconfig['repo_name']
+        elif 'path' in mconfig:
+            return Path(mconfig['path'])
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
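
The new ModelCache.model_name_or_path helper only inspects the requested model's entry in models.yaml: a repo_name entry resolves to a Hugging Face repo id (a plain string), a path entry resolves to a local Path, and anything else is an error. A hedged sketch of those two cases follows; the entry names and values are invented for illustration.

from pathlib import Path
from omegaconf import OmegaConf

config = OmegaConf.create("""
hub-model:
  repo_name: CompVis/stable-diffusion-v1-4
local-model:
  path: /data/models/local-model
""")

# repo_name entries resolve to a repo id string...
print(config['hub-model']['repo_name'])      # CompVis/stable-diffusion-v1-4
# ...while path entries resolve to a local Path.
print(Path(config['local-model']['path']))   # /data/models/local-model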