diff --git a/backend/modules/parameters.py b/backend/modules/parameters.py
index 10af5ece3a..130d04d64a 100644
--- a/backend/modules/parameters.py
+++ b/backend/modules/parameters.py
@@ -12,6 +12,9 @@ SAMPLER_CHOICES = [
     "k_heun",
     "k_lms",
     "plms",
+    # diffusers:
+    "ipndm",
+    "pndm",
 ]
diff --git a/ldm/generate.py b/ldm/generate.py
index 0480543989..9300458611 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -19,7 +19,7 @@ import hashlib
 import cv2
 import skimage
 from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
-    EulerAncestralDiscreteScheduler
+    EulerAncestralDiscreteScheduler, PNDMScheduler, IPNDMScheduler
 from omegaconf import OmegaConf

 from ldm.invoke.generator.base import downsampling
@@ -1004,36 +1004,39 @@ class Generate:
         print(msg)

     def _set_scheduler(self):
-        msg = f'>> Setting Sampler to {self.sampler_name}'
         default = self.model.scheduler
-        # TODO: Test me! Not all schedulers take the same args.
-        scheduler_args = dict(
-            num_train_timesteps=default.num_train_timesteps,
-            beta_start=default.beta_start,
-            beta_end=default.beta_end,
-            beta_schedule=default.beta_schedule,
+
+        higher_order_samplers = [
+            'k_dpm_2',
+            'k_dpm_2_a',
+            'k_heun',
+            'plms',  # its first step is Heun-like
+        ]
+        scheduler_map = dict(
+            ddim=DDIMScheduler,
+            ipndm=IPNDMScheduler,
+            k_euler=EulerDiscreteScheduler,
+            k_euler_a=EulerAncestralDiscreteScheduler,
+            k_lms=LMSDiscreteScheduler,
+            pndm=PNDMScheduler,
         )
-        trained_betas = getattr(self.model.scheduler, 'trained_betas')
-        if trained_betas is not None:
-            scheduler_args.update(trained_betas=trained_betas)
-        if self.sampler_name == 'plms':
-            raise NotImplementedError("What's the diffusers implementation of PLMS?")
-        elif self.sampler_name == 'ddim':
-            self.sampler = DDIMScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_dpm_2_a':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_dpm_2':
-            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
-        elif self.sampler_name == 'k_euler_a':
-            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_euler':
-            self.sampler = EulerDiscreteScheduler(**scheduler_args)
-        elif self.sampler_name == 'k_heun':
-            raise NotImplementedError("no diffusers implementation of Heun's sampler")
-        elif self.sampler_name == 'k_lms':
-            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+
+        if self.sampler_name in scheduler_map:
+            sampler_class = scheduler_map[self.sampler_name]
+            msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
+            self.sampler = sampler_class.from_config(
+                self.model_cache.model_name_or_path(self.model_name),
+                subfolder="scheduler"
+            )
+        elif self.sampler_name in higher_order_samplers:
+            msg = (f'>> Unsupported Sampler: {self.sampler_name}; '
+                   f'diffusers does not yet support higher-order samplers. '
+                   f'Defaulting to {default}')
+            self.sampler = default
         else:
-            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+            msg = (f'>> Unsupported Sampler: {self.sampler_name}. '
+                   f'Defaulting to {default}')
+            self.sampler = default

         print(msg)
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index e746e5bab3..a58617a511 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -108,6 +108,9 @@ SAMPLER_CHOICES = [
     'k_heun',
     'k_lms',
     'plms',
+    # diffusers:
+    'ipndm',
+    'pndm',
 ]

 PRECISION_CHOICES = [
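The _set_scheduler hunk above replaces the hand-built scheduler_args dict with a table lookup: any sampler name found in scheduler_map is mapped to its diffusers scheduler class and instantiated from the scheduler config stored alongside the model weights. A minimal standalone sketch of that loading pattern, mirroring the from_config(..., subfolder="scheduler") call in the hunk (the repo id below is a placeholder for illustration; newer diffusers releases expose the same loading path as from_pretrained):

    from diffusers import EulerDiscreteScheduler

    # Placeholder repo id, for illustration only. Any model layout with a
    # scheduler/ subfolder containing scheduler_config.json loads the same way.
    scheduler = EulerDiscreteScheduler.from_config(
        "CompVis/stable-diffusion-v1-4",
        subfolder="scheduler",
    )

Because the config is read from the model's own scheduler/ subfolder, per-model settings such as beta_schedule and trained_betas come along automatically, which is why the removed scheduler_args plumbing (including the trained_betas special case) is no longer needed.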
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 15650ccbd9..219e813172 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -24,7 +24,7 @@ class Txt2Img(Generator):
         uc, c, extra_conditioning_info = conditioning

         pipeline = self.model
-        # TODO: customize a new pipeline for the given sampler (Scheduler)
+        pipeline.scheduler = sampler

         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 6d8ce990f0..8ad8e3913b 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -318,6 +318,18 @@ class ModelCache(object):

         return pipeline, width, height, model_hash

+    def model_name_or_path(self, model_name:str) -> str | Path:
+        if model_name not in self.config:
+            raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
+
+        mconfig = self.config[model_name]
+        if 'repo_name' in mconfig:
+            return mconfig['repo_name']
+        elif 'path' in mconfig:
+            return Path(mconfig['path'])
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
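The new ModelCache.model_name_or_path helper is what feeds the from_config call in generate.py: it resolves a name from models.yaml to either a Hugging Face repo id (a str) or a local checkout (a Path). A self-contained sketch of the same resolution logic, with a hypothetical dict standing in for the parsed YAML (model names and paths below are illustrative only):

    from pathlib import Path

    # Hypothetical stand-in for the parsed models.yaml; entries are illustrative.
    config = {
        'hub-model': {'repo_name': 'CompVis/stable-diffusion-v1-4'},
        'local-model': {'path': '/srv/models/sd-1-4'},
    }

    def model_name_or_path(model_name: str) -> str | Path:
        if model_name not in config:
            raise ValueError(f'"{model_name}" is not a known model name.')
        mconfig = config[model_name]
        if 'repo_name' in mconfig:
            return mconfig['repo_name']   # hub repo id, returned as str
        elif 'path' in mconfig:
            return Path(mconfig['path'])  # local checkout, returned as Path
        raise ValueError('Model config must specify either repo_name or path.')

    print(model_name_or_path('hub-model'))    # CompVis/stable-diffusion-v1-4
    print(model_name_or_path('local-model'))  # /srv/models/sd-1-4

Note that the str | Path return annotation requires Python 3.10 or newer; on older interpreters the equivalent spelling is typing.Union[str, Path].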