Codesplit SCHEDULER_MAP for reusage

This commit is contained in:
blessedcoolant
2023-05-12 00:40:03 +12:00
parent c1e7460d39
commit 9a383e456d
5 changed files with 27 additions and 54 deletions

View File

@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@ -52,29 +53,13 @@ class NoiseOutput(BaseInvocationOutput):
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = scheduler_map.get(scheduler_name, scheduler_map['ddim'])
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py