Codesplit SCHEDULER_MAP for reusage

This commit is contained in:
blessedcoolant 2023-05-12 00:40:03 +12:00
parent c1e7460d39
commit 9a383e456d
5 changed files with 27 additions and 54 deletions

View File

@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@ -52,29 +53,13 @@ class NoiseOutput(BaseInvocationOutput):
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = scheduler_map.get(scheduler_name, scheduler_map['ddim'])
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py

View File

@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
from .prompting import get_uc_and_c_and_ec
from .prompting.conditioning import log_tokenization
from .stable_diffusion import HuggingFaceConceptsLibrary
from .stable_diffusion.schedulers import SCHEDULER_MAP
from .util import choose_precision, choose_torch_device
def fix_func(orig):
@ -1047,26 +1048,8 @@ class Generate:
def _set_scheduler(self):
default = self.model.scheduler
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
scheduler_map = dict(
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
# provide an alias for compatibility.
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
)
if self.sampler_name in scheduler_map:
sampler_class, sampler_extra_config = scheduler_map[self.sampler_name]
if self.sampler_name in SCHEDULER_MAP:
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
msg = (
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
)

View File

@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8
@ -71,20 +72,6 @@ class InvokeAIGeneratorOutput:
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_map = dict(
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
)
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
@ -176,13 +163,13 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = self.scheduler_map.get(scheduler_name, self.scheduler_map['ddim'])
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py

View File

@ -0,0 +1 @@
from .schedulers import SCHEDULER_MAP

View File

@ -0,0 +1,17 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict(cpu_only=False)),
dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_dpm_2=(KDPM2DiscreteScheduler, dict(cpu_only=False)),
k_dpm_2_a=(KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
k_dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_euler=(EulerDiscreteScheduler, dict(cpu_only=False)),
k_euler_a=(EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
k_heun=(HeunDiscreteScheduler, dict(cpu_only=False)),
k_lms=(LMSDiscreteScheduler, dict(cpu_only=False)),
plms=(PNDMScheduler, dict(cpu_only=False)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)