mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Codesplit SCHEDULER_MAP for reusage
@@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
+from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
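The only functional change in this first hunk is the new import of SCHEDULER_MAP from the shared ...backend.stable_diffusion.schedulers module. That module itself is not shown in this diff; the following is only a minimal sketch of what it presumably exports, assuming the entries are the old scheduler_map moved out unchanged (module path inferred from the import above):

# Hypothetical sketch of the new shared schedulers module (not part of this diff);
# the entries mirror the scheduler_map dict removed in the second hunk below.
import diffusers

SCHEDULER_MAP = dict(
    ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
    k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
    unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True)),
    # ...remaining entries assumed to carry over unchanged from the removed scheduler_map
)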
@@ -52,29 +53,13 @@ class NoiseOutput(BaseInvocationOutput):
     #fmt: on


-# TODO: this seems like a hack
-scheduler_map = dict(
-    ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
-    dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
-    k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
-    k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
-    k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
-    k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
-    k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
-    k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
-    k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
-    plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
-    unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
-)
-
-
 SAMPLER_NAME_VALUES = Literal[
-    tuple(list(scheduler_map.keys()))
+    tuple(list(SCHEDULER_MAP.keys()))
 ]


 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
-    scheduler_class, scheduler_extra_config = scheduler_map.get(scheduler_name, scheduler_map['ddim'])
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
     scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
     scheduler = scheduler_class.from_config(scheduler_config)
     # hack copied over from generate.py
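For context, the lookup that get_scheduler performs after this change amounts to the following. This is only an illustrative sketch against the diffusers API; the names pick_scheduler and demo_map are placeholders and not part of the commit:

import diffusers

def pick_scheduler(name, current_config, scheduler_map):
    # Unknown names fall back to the ddim entry, as in SCHEDULER_MAP.get(..., SCHEDULER_MAP['ddim'])
    scheduler_class, extra_config = scheduler_map.get(name, scheduler_map["ddim"])
    # Per-scheduler extras are merged over the model's existing scheduler config before rebuilding
    return scheduler_class.from_config({**current_config, **extra_config})

# Throwaway example; in InvokeAI the config would come from model.scheduler.config
demo_map = dict(
    ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
    k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
)
base_config = diffusers.DDIMScheduler().config
print(type(pick_scheduler("k_euler", base_config, demo_map)).__name__)  # EulerDiscreteScheduler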