Merge branch 'main' into feat/nodes/add-randomintinvocation

This commit is contained in:
blessedcoolant 2023-05-12 15:21:49 +12:00 committed by GitHub
commit 8c1c9cd702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 301 additions and 255 deletions

View File

@ -52,7 +52,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)") model: str = Field(default="", description="The model to use (currently ignored)")
# fmt: on # fmt: on

View File

@ -33,8 +33,8 @@ class ImageOutput(BaseInvocationOutput):
# fmt: off # fmt: off
type: Literal["image"] = "image" type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image") image: ImageField = Field(default=None, description="The output image")
width: Optional[int] = Field(default=None, description="The width of the image in pixels") width: int = Field(description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels") height: int = Field(description="The height of the image in pixels")
# fmt: on # fmt: on
class Config: class Config:

View File

@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np import numpy as np
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
@ -52,29 +53,20 @@ class NoiseOutput(BaseInvocationOutput):
#fmt: on #fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys())) tuple(list(SCHEDULER_MAP.keys()))
] ]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim') scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler = scheduler_class.from_config(model.scheduler.config)
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False
@ -148,7 +140,7 @@ class TextToLatentsInvocation(BaseInvocation):
noise: Optional[LatentsField] = Field(description="The noise to use") noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)") model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@ -216,7 +208,7 @@ class TextToLatentsInvocation(BaseInvocation):
h_symmetry_time_pct=None,#h_symmetry_time_pct, h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct, v_symmetry_time_pct=None#v_symmetry_time_pct,
), ),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta) ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
return conditioning_data return conditioning_data
@ -293,11 +285,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latent, device=model.device, dtype=latent.dtype latent, device=model.device, dtype=latent.dtype
) )
timesteps, _ = model.get_img2img_timesteps( timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings( result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents, latents=initial_latents,

View File

@ -108,17 +108,20 @@ APP_VERSION = invokeai.version.__version__
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"k_dpm_2_a", "ddpm",
"k_dpm_2", "deis",
"k_dpmpp_2_a", "lms",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"pndm", "pndm",
"heun",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
] ]
PRECISION_CHOICES = [ PRECISION_CHOICES = [
@ -631,7 +634,7 @@ class Args(object):
choices=SAMPLER_CHOICES, choices=SAMPLER_CHOICES,
metavar="SAMPLER_NAME", metavar="SAMPLER_NAME",
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default="k_lms", default="lms",
) )
render_group.add_argument( render_group.add_argument(
"--log_tokenization", "--log_tokenization",

View File

@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
from .prompting import get_uc_and_c_and_ec from .prompting import get_uc_and_c_and_ec
from .prompting.conditioning import log_tokenization from .prompting.conditioning import log_tokenization
from .stable_diffusion import HuggingFaceConceptsLibrary from .stable_diffusion import HuggingFaceConceptsLibrary
from .stable_diffusion.schedulers import SCHEDULER_MAP
from .util import choose_precision, choose_torch_device from .util import choose_precision, choose_torch_device
def fix_func(orig): def fix_func(orig):
@ -141,7 +142,7 @@ class Generate:
model=None, model=None,
conf="configs/models.yaml", conf="configs/models.yaml",
embedding_path=None, embedding_path=None,
sampler_name="k_lms", sampler_name="lms",
ddim_eta=0.0, # deterministic ddim_eta=0.0, # deterministic
full_precision=False, full_precision=False,
precision="auto", precision="auto",
@ -1047,29 +1048,12 @@ class Generate:
def _set_scheduler(self): def _set_scheduler(self):
default = self.model.scheduler default = self.model.scheduler
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672 if self.sampler_name in SCHEDULER_MAP:
scheduler_map = dict( sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
# provide an alias for compatibility.
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
msg = ( msg = (
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})" f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
) )
self.sampler = sampler_class.from_config(self.model.scheduler.config) self.sampler = sampler_class.from_config({**self.model.scheduler.config, **sampler_extra_config})
else: else:
msg = ( msg = (
f" Unsupported Sampler: {self.sampler_name} "+ f" Unsupported Sampler: {self.sampler_name} "+

View File

@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8 downsampling = 8
@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
# we are interposing a wrapper around the original Generator classes so that # we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work. # old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta): class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
def __init__(self, def __init__(self,
model_info: dict, model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
''' '''
Return list of all the schedulers that we currently handle. Return list of all the schedulers that we currently handle.
''' '''
return list(self.scheduler_map.keys()) return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision) return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler = scheduler_class.from_config(model.scheduler.config)
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False

View File

@ -47,6 +47,7 @@ from diffusers import (
LDMTextToImagePipeline, LDMTextToImagePipeline,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
UniPCMultistepScheduler,
StableDiffusionPipeline, StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
@ -1209,6 +1210,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm": elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim": elif scheduler_type == "ddim":
scheduler = scheduler scheduler = scheduler
else: else:

View File

@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
SchedulerMixin, SchedulerMixin,
logging as dlogging, logging as dlogging,
) )
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
@ -68,7 +68,7 @@ class SDModelComponent(Enum):
scheduler="scheduler" scheduler="scheduler"
safety_checker="safety_checker" safety_checker="safety_checker"
feature_extractor="feature_extractor" feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2 DEFAULT_MAX_MODELS = 2
class ModelManager(object): class ModelManager(object):
@ -182,7 +182,7 @@ class ModelManager(object):
vae from the model currently in the GPU. vae from the model currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.vae) return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into """Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no GPU if necessary and return its assigned CLIPTokenizer. If no
@ -190,12 +190,12 @@ class ModelManager(object):
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.tokenizer) return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into """Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model name is provided, return the UNet from the model
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.unet) return self._get_sub_model(model_name, SDModelComponent.unet)
@ -222,7 +222,7 @@ class ModelManager(object):
currently in the GPU. currently in the GPU.
""" """
return self._get_sub_model(model_name, SDModelComponent.scheduler) return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model( def _get_sub_model(
self, self,
model_name: str=None, model_name: str=None,
@ -1228,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk) sha.update(chunk)
hash = sha.hexdigest() hash = sha.hexdigest()
toc = time.time() toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f: with open(hashpath, "w") as f:
f.write(hash) f.write(hash)
return hash return hash

View File

@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
if timesteps is None: if timesteps is None:
self.scheduler.set_timesteps( self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
num_inference_steps, device=self._model_group.device_for(self.unet)
)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator( infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState self.generate_latents_from_embeddings, PipelineIntermediateState
@ -725,12 +728,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
run_id=None, run_id=None,
callback=None, callback=None,
) -> InvokeAIStableDiffusionPipelineOutput: ) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps( timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
num_inference_steps,
strength,
device=self._model_group.device_for(self.unet),
)
result_latents, result_attention_maps = self.latents_from_embeddings( result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like( latents=initial_latents if strength < 1.0 else torch.zeros_like(
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
@ -756,13 +755,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return self.check_for_safety(output, dtype=conditioning_data.dtype) return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps( def get_img2img_timesteps(
self, num_inference_steps: int, strength: float, device self, num_inference_steps: int, strength: float, device=None
) -> (torch.Tensor, int): ) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler assert img2img_pipeline.scheduler is self.scheduler
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps, adjusted_steps = img2img_pipeline.get_timesteps( timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
num_inference_steps, strength, device=device num_inference_steps, strength, device=scheduler_device
) )
# Workaround for low strength resulting in zero timesteps. # Workaround for low strength resulting in zero timesteps.
# TODO: submit upstream fix for zero-step img2img # TODO: submit upstream fix for zero-step img2img
@ -796,9 +801,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_image.dim() == 3: if init_image.dim() == 3:
init_image = init_image.unsqueeze(0) init_image = init_image.unsqueeze(0)
timesteps, _ = self.get_img2img_timesteps( timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
num_inference_steps, strength, device=device
)
# 6. Prepare latent variables # 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents

View File

@ -0,0 +1 @@
from .schedulers import SCHEDULER_MAP

View File

@ -0,0 +1,22 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict()),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict()),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)

View File

@ -4,17 +4,20 @@ from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"k_dpm_2_a", "ddpm",
"k_dpm_2", "deis",
"k_dpmpp_2_a", "lms",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"pndm", "pndm",
"heun",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
] ]

View File

@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode: Start everything in dev mode:
1. Start the dev server: `yarn dev` 1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web` 2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/> 3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds ### Production builds

View File

@ -2,17 +2,28 @@
export const DIFFUSERS_SCHEDULERS: Array<string> = [ export const DIFFUSERS_SCHEDULERS: Array<string> = [
'ddim', 'ddim',
'plms', 'ddpm',
'k_lms', 'deis',
'dpmpp_2', 'lms',
'k_dpm_2', 'pndm',
'k_dpm_2_a', 'heun',
'k_dpmpp_2', 'euler',
'k_euler', 'euler_k',
'k_euler_a', 'euler_a',
'k_heun', 'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_2m_k',
'unipc',
]; ];
export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
(scheduler) => {
return scheduler !== 'dpmpp_2s';
}
);
// Valid image widths // Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map( export const WIDTHS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64 (_x, i) => (i + 1) * 64

View File

@ -47,15 +47,20 @@ export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>; postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler: sampler:
| 'ddim' | 'ddim'
| 'k_dpm_2_a' | 'ddpm'
| 'k_dpm_2' | 'deis'
| 'k_dpmpp_2_a' | 'lms'
| 'k_dpmpp_2' | 'pndm'
| 'k_euler_a' | 'heun'
| 'k_euler' | 'euler'
| 'k_heun' | 'euler_k'
| 'k_lms' | 'euler_a'
| 'plms'; | 'kdpm_2'
| 'kdpm_2_a'
| 'dpmpp_2s'
| 'dpmpp_2m'
| 'dpmpp_2m_k'
| 'unipc';
prompt: Prompt; prompt: Prompt;
seed: number; seed: number;
variations: SeedWeights; variations: SeedWeights;

View File

@ -0,0 +1,54 @@
import { Badge, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import { isNumber, isString } from 'lodash-es';
import { useMemo } from 'react';
type ImageMetadataOverlayProps = {
image: Image;
};
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
const dimensions = useMemo(() => {
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
return;
}
return `${image.metadata?.width} × ${image.metadata?.height}`;
}, [image.metadata]);
const model = useMemo(() => {
if (!isString(image.metadata?.invokeai?.node?.model)) {
return;
}
return image.metadata?.invokeai?.node?.model;
}, [image.metadata]);
return (
<Flex
sx={{
pointerEvents: 'none',
flexDirection: 'column',
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-end',
gap: 2,
}}
>
{dimensions && (
<Badge variant="solid" colorScheme="base">
{dimensions}
</Badge>
)}
{model && (
<Badge variant="solid" colorScheme="base">
{model}
</Badge>
)}
</Flex>
);
};
export default ImageMetadataOverlay;

View File

@ -1,37 +0,0 @@
import { Badge, Box, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
type ImageToImageOverlayProps = {
image: Image;
};
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
return (
<Box
sx={{
top: 0,
left: 0,
w: 'full',
h: 'full',
position: 'absolute',
pointerEvents: 'none',
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-start',
}}
>
<Badge variant="solid" colorScheme="base">
{image.metadata?.width} × {image.metadata?.height}
</Badge>
</Flex>
</Box>
);
};
export default ImageToImageOverlay;

View File

@ -1,4 +1,4 @@
import { Box, Flex, Image, Skeleton, useBoolean } from '@chakra-ui/react'; import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
@ -11,7 +11,8 @@ import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden'; import CurrentImageHidden from './CurrentImageHidden';
import { DragEvent, memo, useCallback } from 'react'; import { DragEvent, memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import CurrentImageFallback from './CurrentImageFallback'; import ImageFallbackSpinner from './ImageFallbackSpinner';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector], [uiSelector, gallerySelector, systemSelector],
@ -50,8 +51,6 @@ const CurrentImagePreview = () => {
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
const { getUrl } = useGetUrl(); const { getUrl } = useGetUrl();
const [isLoaded, { on, off }] = useBoolean();
const handleDragStart = useCallback( const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
if (!image) { if (!image) {
@ -67,11 +66,11 @@ const CurrentImagePreview = () => {
return ( return (
<Flex <Flex
sx={{ sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
width: '100%', width: '100%',
height: '100%', height: '100%',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
}} }}
> >
{progressImage && shouldShowProgressInViewer ? ( {progressImage && shouldShowProgressInViewer ? (
@ -91,28 +90,23 @@ const CurrentImagePreview = () => {
/> />
) : ( ) : (
image && ( image && (
<Image <>
onDragStart={handleDragStart} <Image
fallbackStrategy="beforeLoadOrError" src={getUrl(image.url)}
src={shouldHidePreview ? undefined : getUrl(image.url)} fallbackStrategy="beforeLoadOrError"
width={image.metadata.width || 'auto'} fallback={<ImageFallbackSpinner />}
height={image.metadata.height || 'auto'} onDragStart={handleDragStart}
fallback={ sx={{
shouldHidePreview ? ( objectFit: 'contain',
<CurrentImageHidden /> maxWidth: '100%',
) : ( maxHeight: '100%',
<CurrentImageFallback /> height: 'auto',
) position: 'absolute',
} borderRadius: 'base',
sx={{ }}
objectFit: 'contain', />
maxWidth: '100%', <ImageMetadataOverlay image={image} />
maxHeight: '100%', </>
height: 'auto',
position: 'absolute',
borderRadius: 'base',
}}
/>
) )
)} )}
{shouldShowImageDetails && image && 'metadata' in image && ( {shouldShowImageDetails && image && 'metadata' in image && (

View File

@ -1,4 +1,4 @@
import { Box, Flex, Image } from '@chakra-ui/react'; import { Flex, Image, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
@ -42,6 +42,7 @@ const GalleryProgressImage = () => {
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
aspectRatio: '1/1', aspectRatio: '1/1',
position: 'relative',
}} }}
> >
<Image <Image
@ -61,6 +62,7 @@ const GalleryProgressImage = () => {
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated', imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
}} }}
/> />
<Spinner sx={{ position: 'absolute', top: 1, right: 1, opacity: 0.7 }} />
</Flex> </Flex>
); );
}; };

View File

@ -278,6 +278,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
h: 'full', h: 'full',
transition: 'transform 0.2s ease-out', transition: 'transform 0.2s ease-out',
aspectRatio: '1/1', aspectRatio: '1/1',
cursor: 'pointer',
}} }}
> >
<Image <Image

View File

@ -1,8 +1,8 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react'; import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps; type ImageFallbackSpinnerProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => { const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
const { size = 'xl', ...rest } = props; const { size = 'xl', ...rest } = props;
return ( return (
@ -21,4 +21,4 @@ const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
); );
}; };
export default CurrentImageFallback; export default ImageFallbackSpinner;

View File

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai'; import { Image } from 'app/types/invokeai';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
@ -63,6 +64,29 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload; state.shouldUseSingleGalleryColumn = action.payload;
}, },
}, },
extraReducers(builder) {
builder.addCase(imageReceived.fulfilled, (state, action) => {
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
// which is currently its own object (instead of a reference to an image in results/uploads)
const { imagePath } = action.payload;
const { imageName } = action.meta.arg;
if (state.selectedImage?.name === imageName) {
state.selectedImage.url = imagePath;
}
});
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
// which is currently its own object (instead of a reference to an image in results/uploads)
const { thumbnailPath } = action.payload;
const { thumbnailName } = action.meta.arg;
if (state.selectedImage?.name === thumbnailName) {
state.selectedImage.thumbnail = thumbnailPath;
}
});
},
}); });
export const { export const {

View File

@ -20,7 +20,7 @@ export const iterationGraph = {
model: '', model: '',
progress_images: false, progress_images: false,
prompt: 'dog', prompt: 'dog',
sampler_name: 'k_lms', sampler_name: 'lms',
seamless: false, seamless: false,
steps: 11, steps: 11,
type: 'txt2img', type: 'txt2img',

View File

@ -1,8 +1,12 @@
import { DIFFUSERS_SCHEDULERS } from 'app/constants'; import {
DIFFUSERS_SCHEDULERS,
IMG2IMG_DIFFUSERS_SCHEDULERS,
} from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISelect from 'common/components/IAISelect'; import IAISelect from 'common/components/IAISelect';
import { setSampler } from 'features/parameters/store/generationSlice'; import { setSampler } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -10,6 +14,9 @@ const ParamSampler = () => {
const sampler = useAppSelector( const sampler = useAppSelector(
(state: RootState) => state.generation.sampler (state: RootState) => state.generation.sampler
); );
const activeTabName = useAppSelector(activeTabNameSelector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -23,7 +30,11 @@ const ParamSampler = () => {
label={t('parameters.sampler')} label={t('parameters.sampler')}
value={sampler} value={sampler}
onChange={handleChange} onChange={handleChange}
validValues={DIFFUSERS_SCHEDULERS} validValues={
activeTabName === 'img2img' || activeTabName == 'unifiedCanvas'
? IMG2IMG_DIFFUSERS_SCHEDULERS
: DIFFUSERS_SCHEDULERS
}
minWidth={36} minWidth={36}
/> />
); );

View File

@ -47,7 +47,7 @@ const ImageToImageStrength = () => {
return ( return (
<IAISlider <IAISlider
label={`${t('parameters.strength')}`} label={`${t('parameters.denoisingStrength')}`}
step={step} step={step}
min={min} min={min}
max={sliderMax} max={sliderMax}

View File

@ -1,17 +1,18 @@
import { Flex, Image, Spinner } from '@chakra-ui/react'; import { Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { DragEvent, useCallback, useState } from 'react'; import { DragEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageType } from 'services/api'; import { ImageType } from 'services/api';
import ImageToImageOverlay from 'common/components/ImageToImageOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import ImageFallbackSpinner from 'features/gallery/components/ImageFallbackSpinner';
const selector = createSelector( const selector = createSelector(
[generationSelector], [generationSelector],
@ -30,8 +31,6 @@ const InitialImagePreview = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [isLoaded, setIsLoaded] = useState(false);
const onError = () => { const onError = () => {
dispatch( dispatch(
addToast({ addToast({
@ -42,13 +41,10 @@ const InitialImagePreview = () => {
}) })
); );
dispatch(clearInitialImage()); dispatch(clearInitialImage());
setIsLoaded(false);
}; };
const handleDrop = useCallback( const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
setIsLoaded(false);
const name = e.dataTransfer.getData('invokeai/imageName'); const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
@ -62,48 +58,32 @@ const InitialImagePreview = () => {
sx={{ sx={{
width: 'full', width: 'full',
height: 'full', height: 'full',
position: 'relative',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative',
}} }}
onDrop={handleDrop} onDrop={handleDrop}
> >
<Flex {initialImage?.url && (
sx={{ <>
height: 'full', <Image
width: 'full', src={getUrl(initialImage?.url)}
blur: '5px', fallbackStrategy="beforeLoadOrError"
position: 'relative', fallback={<ImageFallbackSpinner />}
alignItems: 'center', onError={onError}
justifyContent: 'center', sx={{
}} objectFit: 'contain',
> maxWidth: '100%',
{initialImage?.url && ( maxHeight: '100%',
<> height: 'auto',
<Image position: 'absolute',
sx={{ borderRadius: 'base',
objectFit: 'contain', }}
borderRadius: 'base', />
maxHeight: 'full', <ImageMetadataOverlay image={initialImage} />
}} </>
src={getUrl(initialImage?.url)} )}
onError={onError} {!initialImage?.url && <SelectImagePlaceholder />}
onLoad={() => {
setIsLoaded(true);
}}
fallback={
<Flex
sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}
>
<Spinner color="grey" w="5rem" h="5rem" />
</Flex>
}
/>
{isLoaded && <ImageToImageOverlay image={initialImage} />}
</>
)}
{!initialImage?.url && <SelectImagePlaceholder />}
</Flex>
</Flex> </Flex>
); );
}; };

View File

@ -51,7 +51,7 @@ export const initialGenerationState: GenerationState = {
perlin: 0, perlin: 0,
prompt: '', prompt: '',
negativePrompt: '', negativePrompt: '',
sampler: 'k_lms', sampler: 'lms',
seamBlur: 16, seamBlur: 16,
seamSize: 96, seamSize: 96,
seamSteps: 30, seamSteps: 30,

View File

@ -418,6 +418,7 @@ export const systemSlice = createSlice({
state.currentStep = 0; state.currentStep = 0;
state.totalSteps = 0; state.totalSteps = 0;
state.statusTranslationKey = 'common.statusConnected'; state.statusTranslationKey = 'common.statusConnected';
state.progressImage = null;
state.toastQueue.push( state.toastQueue.push(
makeToast({ title: t('toast.canceled'), status: 'warning' }) makeToast({ title: t('toast.canceled'), status: 'warning' })

View File

@ -64,8 +64,6 @@ const ImageToImageTabCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<ParamWidth isDisabled={!shouldFitToWidthHeight} />
<ParamHeight isDisabled={!shouldFitToWidthHeight} />
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box flexGrow={2}> <Box flexGrow={2}>
<ParamSampler /> <ParamSampler />
@ -74,6 +72,8 @@ const ImageToImageTabCoreParameters = () => {
<ModelSelect /> <ModelSelect />
</Box> </Box>
</Flex> </Flex>
<ParamWidth isDisabled={!shouldFitToWidthHeight} />
<ParamHeight isDisabled={!shouldFitToWidthHeight} />
<ImageToImageStrength /> <ImageToImageStrength />
<ImageToImageFit /> <ImageToImageFit />
</Flex> </Flex>

View File

@ -62,8 +62,6 @@ const UnifiedCanvasCoreParameters = () => {
<ParamSteps /> <ParamSteps />
<ParamCFGScale /> <ParamCFGScale />
</Flex> </Flex>
<ParamWidth />
<ParamHeight />
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box flexGrow={2}> <Box flexGrow={2}>
<ParamSampler /> <ParamSampler />
@ -72,8 +70,9 @@ const UnifiedCanvasCoreParameters = () => {
<ModelSelect /> <ModelSelect />
</Box> </Box>
</Flex> </Flex>
<ParamWidth />
<ParamHeight />
<ImageToImageStrength /> <ImageToImageStrength />
<ImageToImageFit />
</Flex> </Flex>
)} )}
</Flex> </Flex>

View File

@ -3,7 +3,7 @@ import LanguageDetector from 'i18next-browser-languagedetector';
import Backend from 'i18next-http-backend'; import Backend from 'i18next-http-backend';
import { initReactI18next } from 'react-i18next'; import { initReactI18next } from 'react-i18next';
import translationEN from '../dist/locales/en.json'; import translationEN from '../public/locales/en.json';
import { LOCALSTORAGE_PREFIX } from 'app/store/constants'; import { LOCALSTORAGE_PREFIX } from 'app/store/constants';
if (import.meta.env.MODE === 'package') { if (import.meta.env.MODE === 'package') {

View File

@ -23,7 +23,7 @@
], ],
"threshold": 0, "threshold": 0,
"postprocessing": null, "postprocessing": null,
"sampler": "k_lms", "sampler": "lms",
"variations": [], "variations": [],
"type": "txt2img" "type": "txt2img"
} }

View File

@ -17,7 +17,7 @@ valid_metadata = {
"width": 512, "width": 512,
"height": 512, "height": 512,
"cfg_scale": 7.5, "cfg_scale": 7.5,
"scheduler": "k_lms", "scheduler": "lms",
"model": "stable-diffusion-1.5", "model": "stable-diffusion-1.5",
}, },
} }