mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'upgrade-diffusers' of https://github.com/blessedcoolant/InvokeAI into upgrade-diffusers
This commit is contained in:
commit
a53e0dce6c
@ -67,7 +67,7 @@ title: Home
|
||||
implementation of Stable Diffusion, the open source text-to-image and
|
||||
image-to-image generator. It provides a streamlined process with various new
|
||||
features and options to aid the image generation process. It runs on Windows,
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
|
||||
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
|
||||
|
@ -12,12 +12,19 @@ from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
|
||||
from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
|
||||
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
import re
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
@ -25,114 +32,48 @@ DEFAULT_INFILL_METHOD = (
|
||||
)
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
from .latent import get_scheduler
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
context: OldModelContext
|
||||
|
||||
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.context = OldModelContext(
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
class InpaintInvocation(BaseInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
control_model = None
|
||||
else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
txt2img = Txt2Img(model, control_model=control_model)
|
||||
outputs = txt2img.generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
control_image=control_image,
|
||||
**self.dict(
|
||||
exclude={"prompt", "control_image" }
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
|
||||
type: Literal["img2img"] = "img2img"
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -265,6 +148,49 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context):
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
|
||||
#unet = unet_info.context.model
|
||||
#vae = vae_info.context.model
|
||||
|
||||
with ExitStack() as stack:
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with vae_info as vae,\
|
||||
unet_info as unet,\
|
||||
ModelPatcher.apply_lora_unet(unet, loras):
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
@ -277,25 +203,31 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
conditioning = self.get_conditioning(context)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
with self.load_model_old_way(context, scheduler) as model:
|
||||
outputs = Inpaint(model).generate(
|
||||
conditioning=conditioning,
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
@ -7,7 +7,7 @@ import einops
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
@ -233,7 +233,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(scheduler, eta=0.0)#ddim_eta)
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, #ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
|
@ -5,7 +5,6 @@ from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
|
@ -5,7 +5,6 @@ from .base import (
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint,
|
||||
Generator,
|
||||
|
@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
@ -81,13 +80,15 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
'''
|
||||
Return an iterator across the indicated number of generations.
|
||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||
@ -116,11 +117,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
with model_info.context as model:
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
@ -143,12 +139,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
results = generator.generate(
|
||||
conditioning=conditioning,
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
@ -170,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
@ -193,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .txt2img import Txt2Img
|
||||
return Txt2Img
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
@ -253,24 +228,6 @@ class Inpaint(Img2Img):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
# ------------------------------------
|
||||
class Embiggen(Txt2Img):
|
||||
def generate(
|
||||
self,
|
||||
embiggen: list=None,
|
||||
embiggen_tiles: list = None,
|
||||
strength: float=0.75,
|
||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .embiggen import Embiggen
|
||||
return Embiggen
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
@ -281,7 +238,7 @@ class Generator:
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
@ -292,7 +249,7 @@ class Generator:
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
def get_make_image(self, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
@ -308,7 +265,6 @@ class Generator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
@ -333,7 +289,6 @@ class Generator:
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
|
@ -1,559 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.embiggen descends from .generator
|
||||
and generates with .generator.img2img
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
for _ in trange(iterations, desc="Generating"):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert (
|
||||
not sampler.uses_inpainting_model()
|
||||
), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model, self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert("RGB")
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
||||
initsuperheight = round(initsuperheight * embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert (
|
||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = (
|
||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
||||
)
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
# Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(
|
||||
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
|
||||
(0, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerAA.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||
# instead track and iterate a local seed we pass to Img2Img
|
||||
seed = self.seed
|
||||
seedintlimit = (
|
||||
np.iinfo(np.uint32).max - 1
|
||||
) # only retreive this one from numpy
|
||||
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and replace
|
||||
if embiggen_tiles and not tile in embiggen_tiles:
|
||||
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||
newinitimage = torch.from_numpy(newinitimage)
|
||||
newinitimage = 2.0 * newinitimage - 1.0
|
||||
newinitimage = newinitimage.to(self.model.device)
|
||||
clear_cuda_cache = (
|
||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
)
|
||||
|
||||
tile_results = gen_img2img.generate(
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=conditioning,
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=None, # called only after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
width=width,
|
||||
height=height,
|
||||
init_image=newinitimage, # notice that init_image is different from init_img
|
||||
mask_image=None,
|
||||
strength=strength,
|
||||
clear_cuda_cache=clear_cuda_cache,
|
||||
)
|
||||
|
||||
emb_tile_store.append(tile_results[0][0])
|
||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||
del newinitimage
|
||||
|
||||
# Sanity check we have them all
|
||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
|
||||
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
|
||||
):
|
||||
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
|
||||
if embiggen_tiles:
|
||||
outputsuperimage.alpha_composite(
|
||||
initsuperimage.convert("RGBA"), (0, 0)
|
||||
)
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
if embiggen_tiles:
|
||||
if tile in embiggen_tiles:
|
||||
intileimage = emb_tile_store.pop(0)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
intileimage = emb_tile_store[tile]
|
||||
intileimage = intileimage.convert("RGBA")
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||
left = 0
|
||||
top = 0
|
||||
else:
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
# Handle gradients for various conditions
|
||||
# Handle emb_rerun case
|
||||
if embiggen_tiles:
|
||||
# top of image
|
||||
if emb_row_i == 0:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) not in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerB)
|
||||
# Otherwise do nothing on this tile
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRBC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerLR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABT)
|
||||
# bottom of image
|
||||
elif emb_row_i == emb_tiles_y - 1:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
# No tiles to look ahead to
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
# vertical middle of image
|
||||
else:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTB)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABL)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerAA)
|
||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||
else:
|
||||
if emb_row_i == 0 and emb_column_i >= 1:
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerT)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
|
||||
# end of function declaration
|
||||
return make_image
|
@ -22,7 +22,6 @@ class Img2Img(Generator):
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
|
@ -161,9 +161,7 @@ class Inpaint(Img2Img):
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -177,8 +175,6 @@ class Inpaint(Img2Img):
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -203,8 +199,6 @@ class Inpaint(Img2Img):
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -306,7 +300,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
@ -345,9 +338,7 @@ class Inpaint(Img2Img):
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -360,8 +351,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
|
@ -1,125 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
PostprocessingSettings,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .base import Generator
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision,
|
||||
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
||||
**kwargs):
|
||||
self.control_model = control_model
|
||||
if isinstance(self.control_model, list):
|
||||
self.control_model = MultiControlNetModel(self.control_model)
|
||||
super().__init__(model, precision, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
control_image = kwargs.get("control_image", None)
|
||||
do_classifier_free_guidance = cfg_scale > 1.0
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.control_model = self.control_model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
if control_image is not None:
|
||||
if isinstance(self.control_model, ControlNetModel):
|
||||
control_image = pipeline.prepare_control_image(
|
||||
image=control_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
elif isinstance(self.control_model, MultiControlNetModel):
|
||||
images = []
|
||||
for image_ in control_image:
|
||||
image_ = pipeline.prepare_control_image(
|
||||
image=image_,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
images.append(image_)
|
||||
control_image = images
|
||||
kwargs["control_image"] = control_image
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
@ -1,209 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
|
||||
from ..stable_diffusion import PostprocessingSettings
|
||||
from .base import Generator
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt: str,
|
||||
sampler,
|
||||
steps: int,
|
||||
cfg_scale: float,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width: int,
|
||||
height: int,
|
||||
strength: float,
|
||||
step_callback: Optional[Callable] = None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=0.2,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int):
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# Get our initial generation width and height directly from the latent output so
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
first_pass_latent_output,
|
||||
size=(
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
),
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
second_pass_noise = self.get_noise_like(
|
||||
resized_latents, override_perlin=True
|
||||
)
|
||||
|
||||
# Clear symmetry for the second pass
|
||||
from dataclasses import replace
|
||||
|
||||
new_postprocessing_settings = replace(
|
||||
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
|
||||
)
|
||||
new_postprocessing_settings = replace(
|
||||
new_postprocessing_settings, v_symmetry_time_pct=None
|
||||
)
|
||||
new_conditioning_data = replace(
|
||||
conditioning_data, postprocessing_settings=new_postprocessing_settings
|
||||
)
|
||||
|
||||
verbosity = get_verbosity()
|
||||
set_verbosity_error()
|
||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||
resized_latents,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=new_conditioning_data,
|
||||
strength=strength,
|
||||
noise=second_pass_noise,
|
||||
callback=step_callback,
|
||||
)
|
||||
set_verbosity(verbosity)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
# FIXME: do we really need something entirely different for the inpainting model?
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpaing model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0 and override_perlin == False:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
shape[3], shape[2]
|
||||
)
|
||||
return x
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self, width, height, scale=True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
# Scale the input width and height for the initial generation
|
||||
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
|
||||
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
|
||||
|
||||
aspect = width / height
|
||||
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
model_area = (
|
||||
dimension * dimension
|
||||
) # hardcoded for now since all models are trained on square images
|
||||
|
||||
if aspect > 1.0:
|
||||
init_height = max(min_dimension, math.sqrt(model_area / aspect))
|
||||
init_width = init_height * aspect
|
||||
else:
|
||||
init_width = max(min_dimension, math.sqrt(model_area * aspect))
|
||||
init_height = init_width / aspect
|
||||
|
||||
scaled_width, scaled_height = trim_to_multiple_of(
|
||||
math.floor(init_width), math.floor(init_height)
|
||||
)
|
||||
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
channels = self.latent_channels
|
||||
if channels == 9:
|
||||
channels = 4 # we don't really want noise for all the mask channels
|
||||
shape = (
|
||||
1,
|
||||
channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor,
|
||||
)
|
||||
if self.use_mps_noise or device.type == "mps":
|
||||
tensor = torch.empty(size=shape, device="cpu")
|
||||
tensor = self.get_noise_like(like=tensor).to(device)
|
||||
else:
|
||||
tensor = torch.empty(size=shape, device=device)
|
||||
tensor = self.get_noise_like(like=tensor)
|
||||
return tensor
|
@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"lms_k",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
@ -18,8 +19,13 @@ SAMPLER_CHOICES = [
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
@ -556,8 +556,8 @@ class ModelPatcher:
|
||||
new_tokens_added = None
|
||||
|
||||
try:
|
||||
ti_manager = TextualInversionManager()
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
|
||||
def _get_trigger(ti, index):
|
||||
@ -650,22 +650,24 @@ class TextualInversionModel:
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = dict()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, token_ids: list[int]
|
||||
) -> list[int]:
|
||||
|
||||
#if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
# raise ValueError("token_ids must not start with bos_token_id")
|
||||
#if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
# raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
|
@ -1,9 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.prompting
|
||||
"""
|
||||
from .conditioning import (
|
||||
get_prompt_structure,
|
||||
get_tokens_for_prompt_object,
|
||||
get_uc_and_c_and_ec,
|
||||
split_weighted_subprompts,
|
||||
)
|
@ -1,297 +0,0 @@
|
||||
"""
|
||||
This module handles the generation of the conditioning tensors.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
|
||||
"""
|
||||
import re
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or config.log_tokenization:
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
||||
for c in blend.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
|
||||
unconditioned_words = ""
|
||||
unconditional_regex = r"\[(.*?)\]"
|
||||
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
|
||||
if len(unconditionals) > 0:
|
||||
unconditioned_words = " ".join(unconditionals)
|
||||
|
||||
# Remove Unconditioned Words From Prompt
|
||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
||||
clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
|
||||
prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
|
||||
else:
|
||||
prompt_string_cleaned = prompt_string_uncleaned
|
||||
return prompt_string_cleaned, unconditioned_words
|
||||
|
||||
|
||||
def log_tokenization(
|
||||
positive_prompt: Union[Blend, FlattenedPrompt],
|
||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||
tokenizer,
|
||||
):
|
||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
|
||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||
log_tokenization_for_prompt_object(
|
||||
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
for i, c in enumerate(blend.prompts):
|
||||
log_tokenization_for_prompt_object(
|
||||
c,
|
||||
tokenizer,
|
||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
||||
)
|
||||
elif type(p) is FlattenedPrompt:
|
||||
flattened_prompt: FlattenedPrompt = p
|
||||
if flattened_prompt.wants_cross_attention_control:
|
||||
original_fragments = []
|
||||
edited_fragments = []
|
||||
for f in flattened_prompt.children:
|
||||
if type(f) is CrossAttentionControlSubstitute:
|
||||
original_fragments += f.original
|
||||
edited_fragments += f.edited
|
||||
else:
|
||||
original_fragments.append(f)
|
||||
edited_fragments.append(f)
|
||||
|
||||
original_text = " ".join([x.text for x in original_fragments])
|
||||
log_tokenization_for_text(
|
||||
original_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap originals)",
|
||||
)
|
||||
edited_text = " ".join([x.text for x in edited_fragments])
|
||||
log_tokenization_for_text(
|
||||
edited_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace("</w>", " ")
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
else:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
logger.debug(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile(
|
||||
"""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
parsed_prompts = [
|
||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||
for match in re.finditer(prompt_parser, text)
|
||||
]
|
||||
if len(parsed_prompts) == 0:
|
||||
return []
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
logger.warning(
|
||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
)
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
from .diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
@ -10,4 +9,3 @@ from .diffusers_pipeline import (
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
@ -1,275 +0,0 @@
|
||||
"""
|
||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
||||
at https://huggingface.co/sd-concepts-library.
|
||||
|
||||
The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
"""
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
"""
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.root = root or self.config.root
|
||||
self.hf_api = HfApi()
|
||||
self.local_concepts = dict()
|
||||
self.concept_list = None
|
||||
self.concepts_loaded = dict()
|
||||
self.triggers = dict() # concept name to trigger phrase
|
||||
self.concept_names = dict() # trigger phrase to concept name
|
||||
self.match_trigger = re.compile(
|
||||
"(<[\w\- >]+>)"
|
||||
) # trigger is slightly less restrictive than HF concept name
|
||||
self.match_concept = re.compile(
|
||||
"<([\w\-]+)>"
|
||||
) # HF concept name can only contain A-Za-z0-9_-
|
||||
|
||||
def list_concepts(self) -> list:
|
||||
"""
|
||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
||||
Also adds local concepts in invokeai/embeddings folder.
|
||||
"""
|
||||
local_concepts_now = self.get_local_concepts(
|
||||
os.path.join(self.root, "embeddings")
|
||||
)
|
||||
local_concepts_to_add = set(local_concepts_now).difference(
|
||||
set(self.local_concepts)
|
||||
)
|
||||
self.local_concepts.update(local_concepts_now)
|
||||
|
||||
if self.concept_list is not None:
|
||||
if local_concepts_to_add:
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
elif self.config.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
)
|
||||
self.concept_list = [a.id.split("/")[1] for a in models]
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
Returns the path to the 'learned_embeds.bin' file in
|
||||
the named concept. Returns None if invalid or cannot
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
||||
|
||||
def concept_to_trigger(self, concept_name: str) -> str:
|
||||
"""
|
||||
Given a concept name returns its trigger by looking in the
|
||||
"token_identifier.txt" file.
|
||||
"""
|
||||
if concept_name in self.triggers:
|
||||
return self.triggers[concept_name]
|
||||
elif self.concept_is_local(concept_name):
|
||||
trigger = f"<{concept_name}>"
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
file = self.get_concept_file(
|
||||
concept_name, "token_identifier.txt", local_only=True
|
||||
)
|
||||
if not file:
|
||||
return None
|
||||
with open(file, "r") as f:
|
||||
trigger = f.readline()
|
||||
trigger = trigger.strip()
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
def trigger_to_concept(self, trigger: str) -> str:
|
||||
"""
|
||||
Given a trigger phrase, maps it to the concept library name.
|
||||
Only works if concept_to_trigger() has previously been called
|
||||
on this library. There needs to be a persistent database for
|
||||
this.
|
||||
"""
|
||||
concept = self.concept_names.get(trigger, None)
|
||||
return f"<{concept}>" if concept else f"{trigger}"
|
||||
|
||||
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
||||
"""
|
||||
Given a prompt string that contains <trigger> tags, replace these
|
||||
tags with the concept name. The reason for this is so that the
|
||||
concept names get stored in the prompt metadata. There is no
|
||||
controlling of colliding triggers in the SD library, so it is
|
||||
better to store the concept name (unique) than the concept trigger
|
||||
(not necessarily unique!)
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
triggers = self.match_trigger.findall(prompt)
|
||||
if not triggers:
|
||||
return prompt
|
||||
|
||||
def do_replace(match) -> str:
|
||||
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_trigger.sub(do_replace, prompt)
|
||||
|
||||
def replace_concepts_with_triggers(
|
||||
self,
|
||||
prompt: str,
|
||||
load_concepts_callback: Callable[[list], any],
|
||||
excluded_tokens: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Given a prompt string that contains `<concept_name>` tags, replace
|
||||
these tags with the appropriate trigger.
|
||||
|
||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||
of `concepts_name` strings.
|
||||
|
||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||
are trigger tokens from a locally-loaded embedding.
|
||||
"""
|
||||
concepts = self.match_concept.findall(prompt)
|
||||
if not concepts:
|
||||
return prompt
|
||||
load_concepts_callback(concepts)
|
||||
|
||||
def do_replace(match) -> str:
|
||||
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
||||
return f"<{match.group(1)}>"
|
||||
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_concept.sub(do_replace, prompt)
|
||||
|
||||
def get_concept_file(
|
||||
self,
|
||||
concept_name: str,
|
||||
file_name: str = "learned_embeds.bin",
|
||||
local_only: bool = False,
|
||||
) -> str:
|
||||
if not (
|
||||
self.concept_is_downloaded(concept_name)
|
||||
or self.concept_is_local(concept_name)
|
||||
or local_only
|
||||
):
|
||||
self.download_concept(concept_name)
|
||||
|
||||
# get local path in invokeai/embeddings if local concept
|
||||
if self.concept_is_local(concept_name):
|
||||
concept_path = self._concept_local_path(concept_name)
|
||||
path = concept_path
|
||||
else:
|
||||
concept_path = self._concept_path(concept_name)
|
||||
path = os.path.join(concept_path, file_name)
|
||||
return path if os.path.exists(path) else None
|
||||
|
||||
def concept_is_local(self, concept_name) -> bool:
|
||||
return concept_name in self.local_concepts
|
||||
|
||||
def concept_is_downloaded(self, concept_name) -> bool:
|
||||
concept_directory = self._concept_path(concept_name)
|
||||
return os.path.exists(concept_directory)
|
||||
|
||||
def download_concept(self, concept_name) -> bool:
|
||||
repo_id = self._concept_id(concept_name)
|
||||
dest = self._concept_path(concept_name)
|
||||
|
||||
access_token = HfFolder.get_token()
|
||||
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
||||
opener = request.build_opener()
|
||||
opener.addheaders = header
|
||||
request.install_opener(opener)
|
||||
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
succeeded = True
|
||||
|
||||
bytes = 0
|
||||
|
||||
def tally_download_size(chunk, size, total):
|
||||
nonlocal bytes
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
"learned_embeds.bin",
|
||||
"token_identifier.txt",
|
||||
"type_of_concept.txt",
|
||||
):
|
||||
url = hf_hub_url(repo_id, file)
|
||||
request.urlretrieve(
|
||||
url, os.path.join(dest, file), reporthook=tally_download_size
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
return f"sd-concepts-library/{concept_name}"
|
||||
|
||||
def _concept_path(self, concept_name: str) -> str:
|
||||
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
||||
|
||||
def _concept_local_path(self, concept_name: str) -> str:
|
||||
filename = self.local_concepts[concept_name]
|
||||
return os.path.join(self.root, "embeddings", filename)
|
||||
|
||||
def get_local_concepts(self, loc_dir: str):
|
||||
locs_dic = dict()
|
||||
if os.path.isdir(loc_dir):
|
||||
for file in os.listdir(loc_dir):
|
||||
f = os.path.splitext(file)
|
||||
if f[1] == ".bin" or f[1] == ".pt":
|
||||
locs_dic[f[0]] = file
|
||||
return locs_dic
|
@ -16,7 +16,6 @@ from accelerate.utils import set_seed
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from compel import EmbeddingsProvider
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@ -48,7 +47,6 @@ from .diffusion import (
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -317,6 +315,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = "float32",
|
||||
control_model: ControlNetModel = None,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@ -341,22 +340,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||
self.unet, self._unet_forward, is_running_diffusers=True
|
||||
)
|
||||
use_full_precision = precision == "float32" or precision == "autocast"
|
||||
self.textual_inversion_manager = TextualInversionManager(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision,
|
||||
)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager,
|
||||
self.unet, self._unet_forward
|
||||
)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
self.control_model = control_model
|
||||
|
||||
@ -404,50 +391,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
def enable_offload_submodels(self, device: torch.device):
|
||||
"""
|
||||
Offload each submodel when it's not in use.
|
||||
|
||||
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
|
||||
the total available resource, and you want to free up as much for inference as possible.
|
||||
|
||||
This requires more moving parts and may add some delay as the U-Net is swapped out for the
|
||||
VAE and vice-versa.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = LazilyLoadedModelGroup(device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def disable_offload_submodels(self):
|
||||
"""
|
||||
Leave all submodels loaded.
|
||||
|
||||
Appropriate for cases where the size of the model in memory is small compared to the memory
|
||||
required for inference. Avoids the delay and complexity of shuffling the submodels to and
|
||||
from the GPU.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = FullyLoadedModelGroup(self._model_group.execution_device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def offload_all(self):
|
||||
"""Offload all this pipeline's models to CPU."""
|
||||
self._model_group.offload_current()
|
||||
|
||||
def ready(self):
|
||||
"""
|
||||
Ready this pipeline's models.
|
||||
|
||||
i.e. preload them to the GPU if appropriate.
|
||||
"""
|
||||
self._model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
@ -991,25 +934,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
device = self._model_group.device_for(self.safety_checker)
|
||||
return super().run_safety_checker(image, device, dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(
|
||||
self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
|
||||
):
|
||||
"""
|
||||
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
|
||||
text_batch=c,
|
||||
fragment_weights_batch=fragment_weights,
|
||||
should_return_tokens=return_tokens,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.config.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
self._model_group.load(self.vae)
|
||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
model,
|
||||
model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool = False,
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
# apparently unused code
|
||||
# TODO: delete
|
||||
# def override_cross_attention(
|
||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
# ) -> Dict[str, AttentionProcessor]:
|
||||
# """
|
||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
# the previous attention processor is returned so that the caller can restore it later.
|
||||
# """
|
||||
# self.conditioning = conditioning
|
||||
# self.cross_attention_control_context = Context(
|
||||
# arguments=self.conditioning.cross_attention_control_args,
|
||||
# step_count=step_count,
|
||||
# )
|
||||
# return override_cross_attention(
|
||||
# self.model,
|
||||
# self.cross_attention_control_context,
|
||||
# is_running_diffusers=self.is_running_diffusers,
|
||||
# )
|
||||
|
||||
def restore_default_cross_attention(
|
||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||
):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
restore_default_cross_attention(
|
||||
self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor,
|
||||
)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_threshold(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = (
|
||||
step_index / total_step_count
|
||||
) # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||
)
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
# TODO: looks unused
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = (
|
||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, edited_conditioning, **kwargs,
|
||||
)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
|
@ -1,13 +1,14 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
@ -16,8 +17,13 @@ SCHEDULER_MAP = dict(
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
@ -1,429 +0,0 @@
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
class EmbeddingInfo:
|
||||
name: str
|
||||
embedding: torch.Tensor
|
||||
num_vectors_per_token: int
|
||||
token_dim: int
|
||||
trained_steps: int = None
|
||||
trained_model_name: str = None
|
||||
trained_model_checksum: str = None
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
embedding: torch.Tensor
|
||||
trigger_token_id: Optional[int] = None
|
||||
pad_token_ids: Optional[list[int]] = None
|
||||
|
||||
@property
|
||||
def embedding_vector_length(self) -> int:
|
||||
return self.embedding.shape[0]
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
full_precision: bool = True,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.full_precision = full_precision
|
||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||
self.trigger_to_sourcefile = dict()
|
||||
default_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = default_textual_inversions
|
||||
|
||||
def load_huggingface_concepts(self, concepts: list[str]):
|
||||
for concept_name in concepts:
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if (
|
||||
self.has_textual_inversion_for_trigger_string(trigger)
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||
): # in case a token with literal angle brackets encountered
|
||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||
|
||||
def get_all_trigger_strings(self) -> list[str]:
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(
|
||||
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
|
||||
):
|
||||
ckpt_path = Path(ckpt_path)
|
||||
|
||||
if not ckpt_path.is_file():
|
||||
return
|
||||
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
logger.warning(
|
||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info.name
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else ckpt_path.name
|
||||
)
|
||||
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
logger.info(
|
||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info.embedding,
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
) -> Optional[TextualInversion]:
|
||||
"""
|
||||
Add a textual inversion to be recognised.
|
||||
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||
:return: The token id for the added embedding, either existing or newly-added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
logger.warning(
|
||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
)
|
||||
return
|
||||
if not self.full_precision:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
elif len(embedding.shape) > 2:
|
||||
raise ValueError(
|
||||
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
||||
)
|
||||
|
||||
try:
|
||||
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
|
||||
if not defer_injecting_tokens:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
self.textual_inversions.append(ti)
|
||||
return ti
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith("Warning"):
|
||||
logger.warning(f"{str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
logger.error(
|
||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
)
|
||||
raise
|
||||
|
||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
||||
if ti.trigger_token_id is not None:
|
||||
raise ValueError(
|
||||
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
|
||||
)
|
||||
|
||||
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
|
||||
ti.trigger_string, ti.embedding[0]
|
||||
)
|
||||
|
||||
if ti.embedding_vector_length > 1:
|
||||
# for embeddings with vector length > 1
|
||||
pad_token_strings = [
|
||||
ti.trigger_string + "-!pad-" + str(pad_index)
|
||||
for pad_index in range(1, ti.embedding_vector_length)
|
||||
]
|
||||
# todo: batched UI for faster loading when vector length >2
|
||||
pad_token_ids = [
|
||||
self._get_or_create_token_id_and_assign_embedding(
|
||||
pad_token_str, ti.embedding[1 + i]
|
||||
)
|
||||
for (i, pad_token_str) in enumerate(pad_token_strings)
|
||||
]
|
||||
else:
|
||||
pad_token_ids = []
|
||||
|
||||
ti.trigger_token_id = trigger_token_id
|
||||
ti.pad_token_ids = pad_token_ids
|
||||
return ti.trigger_token_id
|
||||
|
||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||
try:
|
||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||
return ti is not None
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def get_textual_inversion_for_trigger_string(
|
||||
self, trigger_string: str
|
||||
) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
|
||||
)
|
||||
|
||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
|
||||
)
|
||||
|
||||
def create_deferred_token_ids_for_any_trigger_terms(
|
||||
self, prompt_string: str
|
||||
) -> list[int]:
|
||||
injected_token_ids = []
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
logger.info(
|
||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
)
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
)
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
return injected_token_ids
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, prompt_token_ids: list[int]
|
||||
) -> list[int]:
|
||||
"""
|
||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||
|
||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
||||
"""
|
||||
if len(prompt_token_ids) == 0:
|
||||
return prompt_token_ids
|
||||
|
||||
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||
textual_inversion_trigger_token_ids = [
|
||||
ti.trigger_token_id for ti in self.textual_inversions
|
||||
]
|
||||
prompt_token_ids = prompt_token_ids.copy()
|
||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||
if token_id in textual_inversion_trigger_token_ids:
|
||||
textual_inversion = next(
|
||||
ti
|
||||
for ti in self.textual_inversions
|
||||
if ti.trigger_token_id == token_id
|
||||
)
|
||||
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
|
||||
prompt_token_ids.insert(
|
||||
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
|
||||
)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _get_or_create_token_id_and_assign_embedding(
|
||||
self, token_str: str, embedding: torch.Tensor
|
||||
) -> int:
|
||||
if len(embedding.shape) != 1:
|
||||
raise ValueError(
|
||||
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
|
||||
)
|
||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if existing_token_id == self.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
# the following call is slow - todo make batched for better performance with vector length >1
|
||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id == self.tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||
if (
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
|
||||
!= embedding.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
|
||||
)
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
return token_id
|
||||
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
logger.critical(
|
||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||
|
||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||
|
||||
elif 'emb_params' in keys:
|
||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||
|
||||
else:
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = '<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
trained_steps = embedding_ckpt["step"],
|
||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> List[EmbeddingInfo]:
|
||||
"""
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
token_counter = 0
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
trigger = token if token != '*' \
|
||||
else f'<{basename}>' if token_counter == 0 \
|
||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
return [embedding_info]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
"""
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = token or f"<{basename}>",
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||
token_dim = embedding.size()[0],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
@ -7,6 +7,7 @@ SAMPLER_CHOICES = [
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"lms_k",
|
||||
"pndm",
|
||||
"heun",
|
||||
'heun_k',
|
||||
@ -16,8 +17,13 @@ SAMPLER_CHOICES = [
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
@ -547,7 +547,8 @@
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
@ -1,25 +1,62 @@
|
||||
// TODO: use Enums?
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export const SCHEDULERS = [
|
||||
'ddim',
|
||||
'lms',
|
||||
// zod needs the array to be `as const` to infer the type correctly
|
||||
// this is the source of the `SchedulerParam` type, which is generated by zod
|
||||
export const SCHEDULER_NAMES_AS_CONST = [
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'deis',
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_k',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'deis',
|
||||
'ddpm',
|
||||
'pndm',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'heun_k',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
'pndm',
|
||||
'unipc',
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'kdpm_2_a',
|
||||
] as const;
|
||||
|
||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||
export const DEFAULT_SCHEDULER_NAME = 'euler';
|
||||
|
||||
export const SCHEDULER_NAMES: SchedulerParam[] = [...SCHEDULER_NAMES_AS_CONST];
|
||||
|
||||
export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
|
||||
euler: 'Euler',
|
||||
deis: 'DEIS',
|
||||
ddim: 'DDIM',
|
||||
ddpm: 'DDPM',
|
||||
dpmpp_sde: 'DPM++ SDE',
|
||||
dpmpp_2s: 'DPM++ 2S',
|
||||
dpmpp_2m: 'DPM++ 2M',
|
||||
dpmpp_2m_sde: 'DPM++ 2M SDE',
|
||||
heun: 'Heun',
|
||||
kdpm_2: 'KDPM 2',
|
||||
lms: 'LMS',
|
||||
pndm: 'PNDM',
|
||||
unipc: 'UniPC',
|
||||
euler_k: 'Euler Karras',
|
||||
dpmpp_sde_k: 'DPM++ SDE Karras',
|
||||
dpmpp_2s_k: 'DPM++ 2S Karras',
|
||||
dpmpp_2m_k: 'DPM++ 2M Karras',
|
||||
dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras',
|
||||
heun_k: 'Heun Karras',
|
||||
lms_k: 'LMS Karras',
|
||||
euler_a: 'Euler Ancestral',
|
||||
kdpm_2_a: 'KDPM 2 Ancestral',
|
||||
};
|
||||
|
||||
export type Scheduler = (typeof SCHEDULER_NAMES)[number];
|
||||
|
||||
// Valid upscaling levels
|
||||
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
|
||||
|
@ -1,11 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { imageUpdated, imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Graph } from 'services/api';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
@ -67,112 +66,106 @@ export const addUserInvokedCanvasListener = () => {
|
||||
|
||||
moduleLog.debug(`Generation mode: ${generationMode}`);
|
||||
|
||||
// Build the canvas graph
|
||||
const graphComponents = await buildCanvasGraphComponents(
|
||||
state,
|
||||
generationMode
|
||||
);
|
||||
// Temp placeholders for the init and mask images
|
||||
let canvasInitImage: ImageDTO | undefined;
|
||||
let canvasMaskImage: ImageDTO | undefined;
|
||||
|
||||
if (!graphComponents) {
|
||||
moduleLog.error('Problem building graph');
|
||||
return;
|
||||
}
|
||||
|
||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
||||
|
||||
// Assemble! Note that this graph *does not have the init or mask image set yet!*
|
||||
const nodes: Graph['nodes'] = {
|
||||
[rangeNode.id]: rangeNode,
|
||||
[iterateNode.id]: iterateNode,
|
||||
[baseNode.id]: baseNode,
|
||||
};
|
||||
|
||||
const graph = { nodes, edges };
|
||||
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
|
||||
moduleLog.debug({ data: graph }, 'Canvas graph built');
|
||||
|
||||
// If we are generating img2img or inpaint, we need to upload the init images
|
||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||
const baseFilename = `${uuidv4()}.png`;
|
||||
dispatch(
|
||||
// For img2img and inpaint/outpaint, we need to upload the init images
|
||||
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
|
||||
// upload the image, saving the request id
|
||||
const { requestId: initImageUploadedRequestId } = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
file: new File([baseBlob], 'canvasInitImage.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Wait for the image to be uploaded
|
||||
const [{ payload: baseImageDTO }] = await take(
|
||||
// Wait for the image to be uploaded, matching by request id
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === baseFilename
|
||||
action.meta.requestId === initImageUploadedRequestId
|
||||
);
|
||||
|
||||
// Update the base node with the image name and type
|
||||
baseNode.image = {
|
||||
image_name: baseImageDTO.image_name,
|
||||
};
|
||||
canvasInitImage = payload;
|
||||
}
|
||||
|
||||
// For inpaint, we also need to upload the mask layer
|
||||
if (baseNode.type === 'inpaint') {
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
dispatch(
|
||||
// For inpaint/outpaint, we also need to upload the mask layer
|
||||
if (['inpaint', 'outpaint'].includes(generationMode)) {
|
||||
// upload the image, saving the request id
|
||||
const { requestId: maskImageUploadedRequestId } = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'mask',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Wait for the mask to be uploaded
|
||||
const [{ payload: maskImageDTO }] = await take(
|
||||
// Wait for the image to be uploaded, matching by request id
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === maskFilename
|
||||
action.meta.requestId === maskImageUploadedRequestId
|
||||
);
|
||||
|
||||
// Update the base node with the image name and type
|
||||
baseNode.mask = {
|
||||
image_name: maskImageDTO.image_name,
|
||||
};
|
||||
canvasMaskImage = payload;
|
||||
}
|
||||
|
||||
// Create the session and wait for response
|
||||
dispatch(sessionCreated({ graph }));
|
||||
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
|
||||
const graph = buildCanvasGraph(
|
||||
state,
|
||||
generationMode,
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
|
||||
moduleLog.debug({ graph }, `Canvas graph built`);
|
||||
|
||||
// currently this action is just listened to for logging
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
|
||||
// Create the session, store the request id
|
||||
const { requestId: sessionCreatedRequestId } = dispatch(
|
||||
sessionCreated({ graph })
|
||||
);
|
||||
|
||||
// Take the session created action, matching by its request id
|
||||
const [sessionCreatedAction] = await take(
|
||||
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
||||
sessionCreated.fulfilled.match(action) &&
|
||||
action.meta.requestId === sessionCreatedRequestId
|
||||
);
|
||||
const sessionId = sessionCreatedAction.payload.id;
|
||||
|
||||
// Associate the init image with the session, now that we have the session ID
|
||||
if (
|
||||
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
|
||||
baseNode.image
|
||||
) {
|
||||
if (['img2img', 'inpaint'].includes(generationMode) && canvasInitImage) {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.image.image_name,
|
||||
imageName: canvasInitImage.image_name,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Associate the mask image with the session, now that we have the session ID
|
||||
if (baseNode.type === 'inpaint' && baseNode.mask) {
|
||||
if (['inpaint'].includes(generationMode) && canvasMaskImage) {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.mask.image_name,
|
||||
imageName: canvasMaskImage.image_name,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Prep the canvas staging area if it is not yet initialized
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
@ -15,7 +15,7 @@ export const addUserInvokedImageToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildImageToImageGraph(state);
|
||||
const graph = buildLinearImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
moduleLog.debug({ data: graph }, 'Image to Image graph built');
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
@ -15,7 +15,7 @@ export const addUserInvokedTextToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildTextToImageGraph(state);
|
||||
const graph = buildLinearTextToImageGraph(state);
|
||||
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
|
||||
|
@ -0,0 +1,94 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { memo } from 'react';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<MultiSelect
|
||||
searchable={searchable}
|
||||
styles={() => ({
|
||||
label: {
|
||||
color: 'var(--invokeai-colors-base-300)',
|
||||
fontWeight: 'normal',
|
||||
},
|
||||
searchInput: {
|
||||
'::placeholder': {
|
||||
color: 'var(--invokeai-colors-base-700)',
|
||||
},
|
||||
},
|
||||
input: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-900)',
|
||||
borderWidth: '2px',
|
||||
borderColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
padding: 10,
|
||||
paddingRight: 24,
|
||||
fontWeight: 600,
|
||||
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
|
||||
'&:focus': {
|
||||
borderColor: 'var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
'&:focus-within': {
|
||||
borderColor: 'var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
button: {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
},
|
||||
'&:hover': {
|
||||
backgroundColor: 'var(--invokeai-colors-base-700)',
|
||||
cursor: 'pointer',
|
||||
},
|
||||
},
|
||||
dropdown: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
borderColor: 'var(--invokeai-colors-base-700)',
|
||||
},
|
||||
item: {
|
||||
backgroundColor: 'var(--invokeai-colors-base-800)',
|
||||
color: 'var(--invokeai-colors-base-200)',
|
||||
padding: 6,
|
||||
'&[data-hovered]': {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
},
|
||||
'&[data-active]': {
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
'&:hover': {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
backgroundColor: 'var(--invokeai-colors-base-750)',
|
||||
},
|
||||
},
|
||||
'&[data-selected]': {
|
||||
color: 'var(--invokeai-colors-base-50)',
|
||||
backgroundColor: 'var(--invokeai-colors-accent-650)',
|
||||
fontWeight: 600,
|
||||
'&:hover': {
|
||||
backgroundColor: 'var(--invokeai-colors-accent-600)',
|
||||
},
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 24,
|
||||
padding: 20,
|
||||
button: {
|
||||
color: 'var(--invokeai-colors-base-100)',
|
||||
},
|
||||
},
|
||||
})}
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAIMantineMultiSelect);
|
@ -2,8 +2,7 @@ import { RootState } from 'app/store/store';
|
||||
import { filter, forEach, size } from 'lodash-es';
|
||||
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
||||
import { NonNullableGraph } from '../types/types';
|
||||
|
||||
const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
graph: NonNullableGraph,
|
||||
@ -37,7 +36,7 @@ export const addControlNetToLinearGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
forEach(controlNets, (controlNet, index) => {
|
||||
forEach(controlNets, (controlNet) => {
|
||||
const {
|
||||
controlNetId,
|
||||
isEnabled,
|
||||
|
@ -1,116 +1,39 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
InpaintInvocation,
|
||||
IterateInvocation,
|
||||
RandomRangeInvocation,
|
||||
RangeInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
|
||||
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
|
||||
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
|
||||
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
|
||||
import { buildEdges } from '../edgeBuilders/buildEdges';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
|
||||
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
const buildBaseNode = (
|
||||
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
state: RootState
|
||||
):
|
||||
| TextToImageInvocation
|
||||
| ImageToImageInvocation
|
||||
| InpaintInvocation
|
||||
| undefined => {
|
||||
const overrides = {
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
is_intermediate: true,
|
||||
};
|
||||
|
||||
if (nodeType === 'txt2img') {
|
||||
return buildTxt2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'img2img') {
|
||||
return buildImg2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
||||
return buildInpaintNode(state, overrides);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds the Canvas workflow graph and image blobs.
|
||||
*/
|
||||
export const buildCanvasGraphComponents = async (
|
||||
export const buildCanvasGraph = (
|
||||
state: RootState,
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'
|
||||
): Promise<
|
||||
| {
|
||||
rangeNode: RangeInvocation | RandomRangeInvocation;
|
||||
iterateNode: IterateInvocation;
|
||||
baseNode:
|
||||
| TextToImageInvocation
|
||||
| ImageToImageInvocation
|
||||
| InpaintInvocation;
|
||||
edges: Edge[];
|
||||
}
|
||||
| undefined
|
||||
> => {
|
||||
// The base node is a txt2img, img2img or inpaint node
|
||||
const baseNode = buildBaseNode(generationMode, state);
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
canvasInitImage: ImageDTO | undefined,
|
||||
canvasMaskImage: ImageDTO | undefined
|
||||
) => {
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (!baseNode) {
|
||||
moduleLog.error('Problem building base node');
|
||||
return;
|
||||
if (generationMode === 'txt2img') {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
} else if (generationMode === 'img2img') {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
|
||||
if (baseNode.type === 'inpaint') {
|
||||
const {
|
||||
seamSize,
|
||||
seamBlur,
|
||||
seamSteps,
|
||||
seamStrength,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
} = state.generation;
|
||||
forEach(graph.nodes, (node) => {
|
||||
graph.nodes[node.id].is_intermediate = true;
|
||||
});
|
||||
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } =
|
||||
state.canvas;
|
||||
|
||||
if (boundingBoxScaleMethod !== 'none') {
|
||||
baseNode.inpaint_width = scaledBoundingBoxDimensions.width;
|
||||
baseNode.inpaint_height = scaledBoundingBoxDimensions.height;
|
||||
}
|
||||
|
||||
baseNode.seam_size = seamSize;
|
||||
baseNode.seam_blur = seamBlur;
|
||||
baseNode.seam_strength = seamStrength;
|
||||
baseNode.seam_steps = seamSteps;
|
||||
baseNode.infill_method = infillMethod as InpaintInvocation['infill_method'];
|
||||
|
||||
if (infillMethod === 'tile') {
|
||||
baseNode.tile_size = tileSize;
|
||||
}
|
||||
}
|
||||
|
||||
// We always range and iterate nodes, no matter the iteration count
|
||||
// This is required to provide the correct seeds to the backend engine
|
||||
const rangeNode = buildRangeNode(state);
|
||||
const iterateNode = buildIterateNode();
|
||||
|
||||
// Build the edges for the nodes selected.
|
||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
||||
|
||||
return {
|
||||
rangeNode,
|
||||
iterateNode,
|
||||
baseNode,
|
||||
edges,
|
||||
};
|
||||
return graph;
|
||||
};
|
||||
|
@ -0,0 +1,331 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
ImageDTO,
|
||||
ImageResizeInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
*/
|
||||
export const buildCanvasImageToImageGraph = (
|
||||
state: RootState,
|
||||
initialImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
img2imgStrength: strength,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
id: RANGE_OF_SIZE,
|
||||
// seed - must be connected manually
|
||||
// start: 0,
|
||||
size: iterations,
|
||||
step: 1,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[ITERATE]: {
|
||||
type: 'iterate',
|
||||
id: ITERATE,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
id: LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
strength,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||
}
|
||||
|
||||
// handle `fit`
|
||||
if (initialImage.width !== width || initialImage.height !== height) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.image_name,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||
image_name: initialImage.image_name,
|
||||
});
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// add controlnet
|
||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,224 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
ImageDTO,
|
||||
InpaintInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
ITERATE,
|
||||
MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
INPAINT_GRAPH,
|
||||
INPAINT,
|
||||
} from './constants';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
*/
|
||||
export const buildCanvasInpaintGraph = (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
seamSize,
|
||||
seamBlur,
|
||||
seamSteps,
|
||||
seamStrength,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
} = state.generation;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
nodes: {
|
||||
[INPAINT]: {
|
||||
type: 'inpaint',
|
||||
id: INPAINT,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
image: {
|
||||
image_name: canvasInitImage.image_name,
|
||||
},
|
||||
strength,
|
||||
fit: shouldFitToWidthHeight,
|
||||
mask: {
|
||||
image_name: canvasMaskImage.image_name,
|
||||
},
|
||||
seam_size: seamSize,
|
||||
seam_blur: seamBlur,
|
||||
seam_strength: seamStrength,
|
||||
seam_steps: seamSteps,
|
||||
tile_size: infillMethod === 'tile' ? tileSize : undefined,
|
||||
infill_method: infillMethod as InpaintInvocation['infill_method'],
|
||||
inpaint_width:
|
||||
boundingBoxScaleMethod !== 'none'
|
||||
? scaledBoundingBoxDimensions.width
|
||||
: undefined,
|
||||
inpaint_height:
|
||||
boundingBoxScaleMethod !== 'none'
|
||||
? scaledBoundingBoxDimensions.height
|
||||
: undefined,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
id: RANGE_OF_SIZE,
|
||||
// seed - must be connected manually
|
||||
// start: 0,
|
||||
size: iterations,
|
||||
step: 1,
|
||||
},
|
||||
[ITERATE]: {
|
||||
type: 'iterate',
|
||||
id: ITERATE,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||
}
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,224 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
export const buildCanvasTextToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
id: RANGE_OF_SIZE,
|
||||
// start: 0, // seed - must be connected manually
|
||||
size: iterations,
|
||||
step: 1,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
width,
|
||||
height,
|
||||
},
|
||||
[TEXT_TO_LATENTS]: {
|
||||
type: 't2l',
|
||||
id: TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[ITERATE]: {
|
||||
type: 'iterate',
|
||||
id: ITERATE,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||
}
|
||||
|
||||
// add controlnet
|
||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -1,416 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
CompelInvocation,
|
||||
Graph,
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
IterateInvocation,
|
||||
LatentsToImageInvocation,
|
||||
LatentsToLatentsInvocation,
|
||||
NoiseInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||
const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
const LATENTS_TO_IMAGE = 'latents_to_image';
|
||||
const RESIZE = 'resize_image';
|
||||
const NOISE = 'noise';
|
||||
const RANDOM_INT = 'rand_int';
|
||||
const RANGE_OF_SIZE = 'range_of_size';
|
||||
const ITERATE = 'iterate';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildImageToImageGraph = (state: RootState): Graph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
initialImage,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight,
|
||||
width,
|
||||
height,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
if (!initialImage) {
|
||||
moduleLog.error('No initial image found in state');
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
nodes: {},
|
||||
edges: [],
|
||||
};
|
||||
|
||||
// Create the positive conditioning (prompt) node
|
||||
const positiveConditioningNode: CompelInvocation = {
|
||||
id: POSITIVE_CONDITIONING,
|
||||
type: 'compel',
|
||||
prompt: positivePrompt,
|
||||
model,
|
||||
};
|
||||
|
||||
// Negative conditioning
|
||||
const negativeConditioningNode: CompelInvocation = {
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
type: 'compel',
|
||||
prompt: negativePrompt,
|
||||
model,
|
||||
};
|
||||
|
||||
// This will encode the raster image to latents - but it may get its `image` from a resize node,
|
||||
// so we do not set its `image` property yet
|
||||
const imageToLatentsNode: ImageToLatentsInvocation = {
|
||||
id: IMAGE_TO_LATENTS,
|
||||
type: 'i2l',
|
||||
model,
|
||||
};
|
||||
|
||||
// This does the actual img2img inference
|
||||
const latentsToLatentsNode: LatentsToLatentsInvocation = {
|
||||
id: LATENTS_TO_LATENTS,
|
||||
type: 'l2l',
|
||||
cfg_scale,
|
||||
model,
|
||||
scheduler,
|
||||
steps,
|
||||
strength,
|
||||
};
|
||||
|
||||
// Finally we decode the latents back to an image
|
||||
const latentsToImageNode: LatentsToImageInvocation = {
|
||||
id: LATENTS_TO_IMAGE,
|
||||
type: 'l2i',
|
||||
model,
|
||||
};
|
||||
|
||||
// Add all those nodes to the graph
|
||||
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
||||
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
||||
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
|
||||
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
|
||||
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
||||
|
||||
// Connect the prompt nodes to the imageToLatents node
|
||||
graph.edges.push({
|
||||
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect the image-encoding node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect the image-decoding node
|
||||
graph.edges.push({
|
||||
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* Now we need to handle iterations and random seeds. There are four possible scenarios:
|
||||
* - Single iteration, explicit seed
|
||||
* - Single iteration, random seed
|
||||
* - Multiple iterations, explicit seed
|
||||
* - Multiple iterations, random seed
|
||||
*
|
||||
* They all have different graphs and connections.
|
||||
*/
|
||||
|
||||
// Single iteration, explicit seed
|
||||
if (!shouldRandomizeSeed && iterations === 1) {
|
||||
// Noise node using the explicit seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
seed: seed,
|
||||
};
|
||||
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect noise to l2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Single iteration, random seed
|
||||
if (shouldRandomizeSeed && iterations === 1) {
|
||||
// Random int node to generate the seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect random int to the seed of the noise node
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to l2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Multiple iterations, explicit seed
|
||||
if (!shouldRandomizeSeed && iterations > 1) {
|
||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
||||
// of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
|
||||
// iterations.
|
||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||
id: RANGE_OF_SIZE,
|
||||
type: 'range_of_size',
|
||||
start: seed,
|
||||
size: iterations,
|
||||
};
|
||||
|
||||
// Iterate node to iterate over the seeds generated by the range of size node
|
||||
const iterateNode: IterateInvocation = {
|
||||
id: ITERATE,
|
||||
type: 'iterate',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
};
|
||||
|
||||
// Adding to the graph
|
||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||
graph.nodes[ITERATE] = iterateNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect range of size to iterate
|
||||
graph.edges.push({
|
||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect iterate to noise
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to l2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Multiple iterations, random seed
|
||||
if (shouldRandomizeSeed && iterations > 1) {
|
||||
// Random int node to generate the seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||
id: RANGE_OF_SIZE,
|
||||
type: 'range_of_size',
|
||||
size: iterations,
|
||||
};
|
||||
|
||||
// Iterate node to iterate over the seeds generated by the range of size node
|
||||
const iterateNode: IterateInvocation = {
|
||||
id: ITERATE,
|
||||
type: 'iterate',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
// Adding to the graph
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||
graph.nodes[ITERATE] = iterateNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
|
||||
// Connect range of size to iterate
|
||||
graph.edges.push({
|
||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect iterate to noise
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to l2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
shouldFitToWidthHeight &&
|
||||
(initialImage.width !== width || initialImage.height !== height)
|
||||
) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.image_name,
|
||||
},
|
||||
is_intermediate: true,
|
||||
height,
|
||||
width,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||
image_name: initialImage.image_name,
|
||||
});
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,338 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
ImageResizeInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_LATENTS,
|
||||
RESIZE,
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearImageToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
initialImage,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight,
|
||||
width,
|
||||
height,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
if (!initialImage) {
|
||||
moduleLog.error('No initial image found in state');
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
id: RANGE_OF_SIZE,
|
||||
// seed - must be connected manually
|
||||
// start: 0,
|
||||
size: iterations,
|
||||
step: 1,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[ITERATE]: {
|
||||
type: 'iterate',
|
||||
id: ITERATE,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
id: LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
strength,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||
}
|
||||
|
||||
// handle `fit`
|
||||
if (
|
||||
shouldFitToWidthHeight &&
|
||||
(initialImage.width !== width || initialImage.height !== height)
|
||||
) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.image_name,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||
image_name: initialImage.image_name,
|
||||
});
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// add controlnet
|
||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,226 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MODEL_LOADER,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RANDOM_INT,
|
||||
RANGE_OF_SIZE,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
type TextToImageGraphOverrides = {
|
||||
width: number;
|
||||
height: number;
|
||||
};
|
||||
|
||||
export const buildLinearTextToImageGraph = (
|
||||
state: RootState,
|
||||
overrides?: TextToImageGraphOverrides
|
||||
): NonNullableGraph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'compel',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
id: RANGE_OF_SIZE,
|
||||
// start: 0, // seed - must be connected manually
|
||||
size: iterations,
|
||||
step: 1,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
width: overrides?.width || width,
|
||||
height: overrides?.height || height,
|
||||
},
|
||||
[TEXT_TO_LATENTS]: {
|
||||
type: 't2l',
|
||||
id: TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
},
|
||||
[ITERATE]: {
|
||||
type: 'iterate',
|
||||
id: ITERATE,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: RANGE_OF_SIZE,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
} else {
|
||||
// User specified seed, so set the start of the range of size to the seed
|
||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||
}
|
||||
|
||||
// add controlnet
|
||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -1,316 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
CompelInvocation,
|
||||
Graph,
|
||||
IterateInvocation,
|
||||
LatentsToImageInvocation,
|
||||
NoiseInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
TextToLatentsInvocation,
|
||||
} from 'services/api';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
|
||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||
const TEXT_TO_LATENTS = 'text_to_latents';
|
||||
const LATENTS_TO_IMAGE = 'latents_to_image';
|
||||
const NOISE = 'noise';
|
||||
const RANDOM_INT = 'rand_int';
|
||||
const RANGE_OF_SIZE = 'range_of_size';
|
||||
const ITERATE = 'iterate';
|
||||
|
||||
/**
|
||||
* Builds the Text to Image tab graph.
|
||||
*/
|
||||
export const buildTextToImageGraph = (state: RootState): Graph => {
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
iterations,
|
||||
seed,
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
nodes: {},
|
||||
edges: [],
|
||||
};
|
||||
|
||||
// Create the conditioning, t2l and l2i nodes
|
||||
const positiveConditioningNode: CompelInvocation = {
|
||||
id: POSITIVE_CONDITIONING,
|
||||
type: 'compel',
|
||||
prompt: positivePrompt,
|
||||
model,
|
||||
};
|
||||
|
||||
const negativeConditioningNode: CompelInvocation = {
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
type: 'compel',
|
||||
prompt: negativePrompt,
|
||||
model,
|
||||
};
|
||||
|
||||
const textToLatentsNode: TextToLatentsInvocation = {
|
||||
id: TEXT_TO_LATENTS,
|
||||
type: 't2l',
|
||||
cfg_scale,
|
||||
model,
|
||||
scheduler,
|
||||
steps,
|
||||
};
|
||||
|
||||
const latentsToImageNode: LatentsToImageInvocation = {
|
||||
id: LATENTS_TO_IMAGE,
|
||||
type: 'l2i',
|
||||
model,
|
||||
};
|
||||
|
||||
// Add to the graph
|
||||
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
||||
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
||||
graph.nodes[TEXT_TO_LATENTS] = textToLatentsNode;
|
||||
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
||||
|
||||
// Connect them
|
||||
graph.edges.push({
|
||||
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: TEXT_TO_LATENTS, field: 'latents' },
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* Now we need to handle iterations and random seeds. There are four possible scenarios:
|
||||
* - Single iteration, explicit seed
|
||||
* - Single iteration, random seed
|
||||
* - Multiple iterations, explicit seed
|
||||
* - Multiple iterations, random seed
|
||||
*
|
||||
* They all have different graphs and connections.
|
||||
*/
|
||||
|
||||
// Single iteration, explicit seed
|
||||
if (!shouldRandomizeSeed && iterations === 1) {
|
||||
// Noise node using the explicit seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
seed: seed,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect noise to l2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Single iteration, random seed
|
||||
if (shouldRandomizeSeed && iterations === 1) {
|
||||
// Random int node to generate the seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect random int to the seed of the noise node
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to t2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Multiple iterations, explicit seed
|
||||
if (!shouldRandomizeSeed && iterations > 1) {
|
||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
||||
// of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
|
||||
// iterations.
|
||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||
id: RANGE_OF_SIZE,
|
||||
type: 'range_of_size',
|
||||
start: seed,
|
||||
size: iterations,
|
||||
};
|
||||
|
||||
// Iterate node to iterate over the seeds generated by the range of size node
|
||||
const iterateNode: IterateInvocation = {
|
||||
id: ITERATE,
|
||||
type: 'iterate',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
// Adding to the graph
|
||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||
graph.nodes[ITERATE] = iterateNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect range of size to iterate
|
||||
graph.edges.push({
|
||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect iterate to noise
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to t2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Multiple iterations, random seed
|
||||
if (shouldRandomizeSeed && iterations > 1) {
|
||||
// Random int node to generate the seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
id: RANDOM_INT,
|
||||
type: 'rand_int',
|
||||
};
|
||||
|
||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||
id: RANGE_OF_SIZE,
|
||||
type: 'range_of_size',
|
||||
size: iterations,
|
||||
};
|
||||
|
||||
// Iterate node to iterate over the seeds generated by the range of size node
|
||||
const iterateNode: IterateInvocation = {
|
||||
id: ITERATE,
|
||||
type: 'iterate',
|
||||
};
|
||||
|
||||
// Noise node without any seed
|
||||
const noiseNode: NoiseInvocation = {
|
||||
id: NOISE,
|
||||
type: 'noise',
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
// Adding to the graph
|
||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||
graph.nodes[ITERATE] = iterateNode;
|
||||
graph.nodes[NOISE] = noiseNode;
|
||||
|
||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||
graph.edges.push({
|
||||
source: { node_id: RANDOM_INT, field: 'a' },
|
||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||
});
|
||||
|
||||
// Connect range of size to iterate
|
||||
graph.edges.push({
|
||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||
destination: {
|
||||
node_id: ITERATE,
|
||||
field: 'collection',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect iterate to noise
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: ITERATE,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'seed',
|
||||
},
|
||||
});
|
||||
|
||||
// Connect noise to t2l
|
||||
graph.edges.push({
|
||||
source: { node_id: NOISE, field: 'noise' },
|
||||
destination: {
|
||||
node_id: TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,20 @@
|
||||
// friendly node ids
|
||||
export const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||
export const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||
export const TEXT_TO_LATENTS = 'text_to_latents';
|
||||
export const LATENTS_TO_IMAGE = 'latents_to_image';
|
||||
export const NOISE = 'noise';
|
||||
export const RANDOM_INT = 'rand_int';
|
||||
export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const MODEL_LOADER = 'model_loader';
|
||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
export const INPAINT = 'inpaint';
|
||||
export const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||
export const INPAINT_GRAPH = 'inpaint_graph';
|
@ -1,12 +1,11 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -14,30 +13,36 @@ import { useTranslation } from 'react-i18next';
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const allSchedulers: string[] = ui.schedulers
|
||||
.slice()
|
||||
.sort((a, b) => a.localeCompare(b));
|
||||
const { scheduler } = generation;
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = SCHEDULER_NAMES.map((schedulerName) => ({
|
||||
value: schedulerName,
|
||||
label: SCHEDULER_LABEL_MAP[schedulerName as SchedulerParam],
|
||||
group: enabledSchedulers.includes(schedulerName)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
scheduler: generation.scheduler,
|
||||
allSchedulers,
|
||||
scheduler,
|
||||
data,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamScheduler = () => {
|
||||
const { allSchedulers, scheduler } = useAppSelector(selector);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { scheduler, data } = useAppSelector(selector);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setScheduler(v as Scheduler));
|
||||
dispatch(setScheduler(v as SchedulerParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -46,7 +51,7 @@ const ParamScheduler = () => {
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={allSchedulers}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
|
@ -1,12 +1,12 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
import { memo } from 'react';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
|
||||
const ParamSchedulerAndModel = () => {
|
||||
return (
|
||||
<Flex gap={3} w="full">
|
||||
<Box w="16rem">
|
||||
<Box w="20rem">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
<Box w="full">
|
||||
|
@ -1,10 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { Scheduler } from 'app/constants';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
@ -17,7 +17,7 @@ import {
|
||||
StrengthParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||
|
||||
export interface GenerationState {
|
||||
cfgScale: CfgScaleParam;
|
||||
@ -63,7 +63,7 @@ export const initialGenerationState: GenerationState = {
|
||||
perlin: 0,
|
||||
positivePrompt: '',
|
||||
negativePrompt: '',
|
||||
scheduler: 'euler',
|
||||
scheduler: DEFAULT_SCHEDULER_NAME,
|
||||
seamBlur: 16,
|
||||
seamSize: 96,
|
||||
seamSteps: 30,
|
||||
@ -133,7 +133,7 @@ export const generationSlice = createSlice({
|
||||
setWidth: (state, action: PayloadAction<number>) => {
|
||||
state.width = action.payload;
|
||||
},
|
||||
setScheduler: (state, action: PayloadAction<Scheduler>) => {
|
||||
setScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||
state.scheduler = action.payload;
|
||||
},
|
||||
setSeed: (state, action: PayloadAction<number>) => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
|
||||
import { NUMPY_RAND_MAX, SCHEDULER_NAMES_AS_CONST } from 'app/constants';
|
||||
import { z } from 'zod';
|
||||
|
||||
/**
|
||||
@ -73,7 +73,7 @@ export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
|
||||
/**
|
||||
* Zod schema for scheduler parameter
|
||||
*/
|
||||
export const zScheduler = z.enum(SCHEDULERS);
|
||||
export const zScheduler = z.enum(SCHEDULER_NAMES_AS_CONST);
|
||||
/**
|
||||
* Type alias for scheduler parameter, inferred from its zod schema
|
||||
*/
|
||||
|
@ -1,47 +1,44 @@
|
||||
import {
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItemOption,
|
||||
MenuList,
|
||||
MenuOptionGroup,
|
||||
} from '@chakra-ui/react';
|
||||
import { SCHEDULERS } from 'app/constants';
|
||||
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { setSchedulers } from 'features/ui/store/uiSlice';
|
||||
import { isArray } from 'lodash-es';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function SettingsSchedulers() {
|
||||
const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
|
||||
const data = map(SCHEDULER_NAMES, (s) => ({
|
||||
value: s,
|
||||
label: SCHEDULER_LABEL_MAP[s],
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
export default function SettingsSchedulers() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const schedulerSettingsHandler = (v: string | string[]) => {
|
||||
if (isArray(v)) dispatch(setSchedulers(v.sort()));
|
||||
};
|
||||
const enabledSchedulers = useAppSelector(
|
||||
(state: RootState) => state.ui.favoriteSchedulers
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string[]) => {
|
||||
dispatch(favoriteSchedulersChanged(v as SchedulerParam[]));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu closeOnSelect={false}>
|
||||
<MenuButton as={IAIButton}>
|
||||
{t('settings.availableSchedulers')}
|
||||
</MenuButton>
|
||||
<MenuList maxHeight={64} overflowY="scroll">
|
||||
<MenuOptionGroup
|
||||
value={schedulers}
|
||||
type="checkbox"
|
||||
onChange={schedulerSettingsHandler}
|
||||
>
|
||||
{SCHEDULERS.map((scheduler) => (
|
||||
<MenuItemOption key={scheduler} value={scheduler}>
|
||||
{scheduler}
|
||||
</MenuItemOption>
|
||||
))}
|
||||
</MenuOptionGroup>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
<IAIMantineMultiSelect
|
||||
label={t('settings.favoriteSchedulers')}
|
||||
value={enabledSchedulers}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
clearable
|
||||
searchable
|
||||
maxSelectedValues={99}
|
||||
placeholder={t('settings.favoriteSchedulersPlaceholder')}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import ParamSeedCollapse from 'features/parameters/components/Parameters/Seed/ParamSeedCollapse';
|
||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
|
||||
@ -8,6 +7,7 @@ import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
|
||||
import { memo } from 'react';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||
|
||||
const UnifiedCanvasParameters = () => {
|
||||
return (
|
||||
@ -16,6 +16,7 @@ const UnifiedCanvasParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<UnifiedCanvasCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
<ParamSeamCorrectionCollapse />
|
||||
|
@ -1,10 +1,10 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SCHEDULERS } from 'app/constants';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
@ -20,7 +20,7 @@ export const initialUIState: UIState = {
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
shouldShowProgressInViewer: true,
|
||||
schedulers: SCHEDULERS,
|
||||
favoriteSchedulers: [],
|
||||
};
|
||||
|
||||
export const uiSlice = createSlice({
|
||||
@ -94,9 +94,11 @@ export const uiSlice = createSlice({
|
||||
setShouldShowProgressInViewer: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowProgressInViewer = action.payload;
|
||||
},
|
||||
setSchedulers: (state, action: PayloadAction<string[]>) => {
|
||||
state.schedulers = [];
|
||||
state.schedulers = action.payload;
|
||||
favoriteSchedulersChanged: (
|
||||
state,
|
||||
action: PayloadAction<SchedulerParam[]>
|
||||
) => {
|
||||
state.favoriteSchedulers = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@ -124,7 +126,7 @@ export const {
|
||||
toggleParametersPanel,
|
||||
toggleGalleryPanel,
|
||||
setShouldShowProgressInViewer,
|
||||
setSchedulers,
|
||||
favoriteSchedulersChanged,
|
||||
} = uiSlice.actions;
|
||||
|
||||
export default uiSlice.reducer;
|
||||
|
@ -1,3 +1,5 @@
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
|
||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||
|
||||
export type Coordinates = {
|
||||
@ -26,5 +28,5 @@ export interface UIState {
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
shouldShowProgressInViewer: boolean;
|
||||
schedulers: string[];
|
||||
favoriteSchedulers: SchedulerParam[];
|
||||
}
|
||||
|
@ -7,9 +7,11 @@ export { OpenAPI } from './core/OpenAPI';
|
||||
export type { OpenAPIConfig } from './core/OpenAPI';
|
||||
|
||||
export type { AddInvocation } from './models/AddInvocation';
|
||||
export type { BaseModelType } from './models/BaseModelType';
|
||||
export type { Body_upload_image } from './models/Body_upload_image';
|
||||
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
||||
export type { CkptModelInfo } from './models/CkptModelInfo';
|
||||
export type { ClipField } from './models/ClipField';
|
||||
export type { CollectInvocation } from './models/CollectInvocation';
|
||||
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
|
||||
export type { ColorField } from './models/ColorField';
|
||||
@ -53,7 +55,6 @@ export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation
|
||||
export type { ImageRecordChanges } from './models/ImageRecordChanges';
|
||||
export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
|
||||
export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
|
||||
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
||||
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
||||
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
||||
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
||||
@ -62,6 +63,14 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||
export type { IntOutput } from './models/IntOutput';
|
||||
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
||||
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
||||
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
||||
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
|
||||
export type { IterateInvocation } from './models/IterateInvocation';
|
||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||
export type { LatentsField } from './models/LatentsField';
|
||||
@ -71,12 +80,20 @@ export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvoca
|
||||
export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
|
||||
export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
|
||||
export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
||||
export type { LoraInfo } from './models/LoraInfo';
|
||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||
export type { MaskOutput } from './models/MaskOutput';
|
||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||
export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
|
||||
export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
|
||||
export type { ModelError } from './models/ModelError';
|
||||
export type { ModelInfo } from './models/ModelInfo';
|
||||
export type { ModelLoaderOutput } from './models/ModelLoaderOutput';
|
||||
export type { ModelsList } from './models/ModelsList';
|
||||
export type { ModelType } from './models/ModelType';
|
||||
export type { ModelVariantType } from './models/ModelVariantType';
|
||||
export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||
export type { NoiseOutput } from './models/NoiseOutput';
|
||||
@ -97,12 +114,17 @@ export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
|
||||
export type { ResourceOrigin } from './models/ResourceOrigin';
|
||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
||||
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
||||
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||
export type { SubModelType } from './models/SubModelType';
|
||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||
export type { TextToImageInvocation } from './models/TextToImageInvocation';
|
||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||
export type { UNetField } from './models/UNetField';
|
||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||
export type { VaeField } from './models/VaeField';
|
||||
export type { VaeRepo } from './models/VaeRepo';
|
||||
export type { ValidationError } from './models/ValidationError';
|
||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type BaseModelType = 'sd-1' | 'sd-2';
|
@ -7,6 +7,14 @@ export type CkptModelInfo = {
|
||||
* A description of the model
|
||||
*/
|
||||
description?: string;
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* The type of the model
|
||||
*/
|
||||
model_type: string;
|
||||
format?: 'ckpt';
|
||||
/**
|
||||
* The path to the model config
|
||||
|
22
invokeai/frontend/web/src/services/api/models/ClipField.ts
Normal file
22
invokeai/frontend/web/src/services/api/models/ClipField.ts
Normal file
@ -0,0 +1,22 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { LoraInfo } from './LoraInfo';
|
||||
import type { ModelInfo } from './ModelInfo';
|
||||
|
||||
export type ClipField = {
|
||||
/**
|
||||
* Info to load tokenizer submodel
|
||||
*/
|
||||
tokenizer: ModelInfo;
|
||||
/**
|
||||
* Info to load text_encoder submodel
|
||||
*/
|
||||
text_encoder: ModelInfo;
|
||||
/**
|
||||
* Loras to apply on model loading
|
||||
*/
|
||||
loras: Array<LoraInfo>;
|
||||
};
|
||||
|
@ -2,6 +2,8 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ClipField } from './ClipField';
|
||||
|
||||
/**
|
||||
* Parse prompt using compel package to conditioning.
|
||||
*/
|
||||
@ -20,8 +22,8 @@ export type CompelInvocation = {
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* Model to use
|
||||
* Clip to use
|
||||
*/
|
||||
model?: string;
|
||||
clip?: ClipField;
|
||||
};
|
||||
|
||||
|
@ -9,7 +9,15 @@ export type DiffusersModelInfo = {
|
||||
* A description of the model
|
||||
*/
|
||||
description?: string;
|
||||
format?: 'diffusers';
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* The type of the model
|
||||
*/
|
||||
model_type: string;
|
||||
format?: 'folder';
|
||||
/**
|
||||
* The VAE repo to use for this model
|
||||
*/
|
||||
|
@ -26,7 +26,6 @@ import type { ImagePasteInvocation } from './ImagePasteInvocation';
|
||||
import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
|
||||
import type { ImageResizeInvocation } from './ImageResizeInvocation';
|
||||
import type { ImageScaleInvocation } from './ImageScaleInvocation';
|
||||
import type { ImageToImageInvocation } from './ImageToImageInvocation';
|
||||
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
|
||||
import type { InfillColorInvocation } from './InfillColorInvocation';
|
||||
import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
|
||||
@ -38,6 +37,7 @@ import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
|
||||
import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation';
|
||||
import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation';
|
||||
import type { LoadImageInvocation } from './LoadImageInvocation';
|
||||
import type { LoraLoaderInvocation } from './LoraLoaderInvocation';
|
||||
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
|
||||
import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation';
|
||||
import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation';
|
||||
@ -56,10 +56,11 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
|
||||
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
|
||||
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
|
||||
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
|
||||
import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation';
|
||||
import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation';
|
||||
import type { ShowImageInvocation } from './ShowImageInvocation';
|
||||
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
|
||||
import type { SubtractInvocation } from './SubtractInvocation';
|
||||
import type { TextToImageInvocation } from './TextToImageInvocation';
|
||||
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
|
||||
import type { UpscaleInvocation } from './UpscaleInvocation';
|
||||
import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
|
||||
@ -72,7 +73,7 @@ export type Graph = {
|
||||
/**
|
||||
* The nodes in this graph
|
||||
*/
|
||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
|
||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>;
|
||||
/**
|
||||
* The connections between nodes and their fields in this graph
|
||||
*/
|
||||
|
@ -14,7 +14,9 @@ import type { IntCollectionOutput } from './IntCollectionOutput';
|
||||
import type { IntOutput } from './IntOutput';
|
||||
import type { IterateInvocationOutput } from './IterateInvocationOutput';
|
||||
import type { LatentsOutput } from './LatentsOutput';
|
||||
import type { LoraLoaderOutput } from './LoraLoaderOutput';
|
||||
import type { MaskOutput } from './MaskOutput';
|
||||
import type { ModelLoaderOutput } from './ModelLoaderOutput';
|
||||
import type { NoiseOutput } from './NoiseOutput';
|
||||
import type { PromptCollectionOutput } from './PromptCollectionOutput';
|
||||
import type { PromptOutput } from './PromptOutput';
|
||||
@ -46,7 +48,7 @@ export type GraphExecutionState = {
|
||||
/**
|
||||
* The results of node executions
|
||||
*/
|
||||
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | ModelLoaderOutput | LoraLoaderOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||
/**
|
||||
* Errors raised when executing nodes
|
||||
*/
|
||||
|
@ -1,77 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Generates an image using img2img.
|
||||
*/
|
||||
export type ImageToImageInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'img2img';
|
||||
/**
|
||||
* The prompt to generate an image from
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* The seed to use (omit for random)
|
||||
*/
|
||||
seed?: number;
|
||||
/**
|
||||
* The number of steps to use to generate the image
|
||||
*/
|
||||
steps?: number;
|
||||
/**
|
||||
* The width of the resulting image
|
||||
*/
|
||||
width?: number;
|
||||
/**
|
||||
* The height of the resulting image
|
||||
*/
|
||||
height?: number;
|
||||
/**
|
||||
* The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
||||
*/
|
||||
cfg_scale?: number;
|
||||
/**
|
||||
* The scheduler to use
|
||||
*/
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
||||
/**
|
||||
* The model to use (currently ignored)
|
||||
*/
|
||||
model?: string;
|
||||
/**
|
||||
* Whether or not to produce progress images during generation
|
||||
*/
|
||||
progress_images?: boolean;
|
||||
/**
|
||||
* The control model to use
|
||||
*/
|
||||
control_model?: string;
|
||||
/**
|
||||
* The processed control image
|
||||
*/
|
||||
control_image?: ImageField;
|
||||
/**
|
||||
* The input image
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The strength of the original image
|
||||
*/
|
||||
strength?: number;
|
||||
/**
|
||||
* Whether or not the result should be fit to the aspect ratio of the input image
|
||||
*/
|
||||
fit?: boolean;
|
||||
};
|
||||
|
@ -3,6 +3,7 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
import type { VaeField } from './VaeField';
|
||||
|
||||
/**
|
||||
* Encodes an image into latents.
|
||||
@ -22,8 +23,12 @@ export type ImageToLatentsInvocation = {
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The model to use
|
||||
* Vae submodel
|
||||
*/
|
||||
model?: string;
|
||||
vae?: VaeField;
|
||||
/**
|
||||
* Encode latents by overlaping tiles(less memory consumption)
|
||||
*/
|
||||
tiled?: boolean;
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,10 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ColorField } from './ColorField';
|
||||
import type { ConditioningField } from './ConditioningField';
|
||||
import type { ImageField } from './ImageField';
|
||||
import type { UNetField } from './UNetField';
|
||||
import type { VaeField } from './VaeField';
|
||||
|
||||
/**
|
||||
* Generates an image using inpaint.
|
||||
@ -19,9 +22,13 @@ export type InpaintInvocation = {
|
||||
is_intermediate?: boolean;
|
||||
type?: 'inpaint';
|
||||
/**
|
||||
* The prompt to generate an image from
|
||||
* Positive conditioning for generation
|
||||
*/
|
||||
prompt?: string;
|
||||
positive_conditioning?: ConditioningField;
|
||||
/**
|
||||
* Negative conditioning for generation
|
||||
*/
|
||||
negative_conditioning?: ConditioningField;
|
||||
/**
|
||||
* The seed to use (omit for random)
|
||||
*/
|
||||
@ -45,23 +52,15 @@ export type InpaintInvocation = {
|
||||
/**
|
||||
* The scheduler to use
|
||||
*/
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
|
||||
/**
|
||||
* The model to use (currently ignored)
|
||||
* UNet model
|
||||
*/
|
||||
model?: string;
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* Whether or not to produce progress images during generation
|
||||
* Vae model
|
||||
*/
|
||||
progress_images?: boolean;
|
||||
/**
|
||||
* The control model to use
|
||||
*/
|
||||
control_model?: string;
|
||||
/**
|
||||
* The processed control image
|
||||
*/
|
||||
control_image?: ImageField;
|
||||
vae?: VaeField;
|
||||
/**
|
||||
* The input image
|
||||
*/
|
||||
|
@ -3,6 +3,7 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import type { LatentsField } from './LatentsField';
|
||||
import type { VaeField } from './VaeField';
|
||||
|
||||
/**
|
||||
* Generates an image from latents.
|
||||
@ -22,8 +23,12 @@ export type LatentsToImageInvocation = {
|
||||
*/
|
||||
latents?: LatentsField;
|
||||
/**
|
||||
* The model to use
|
||||
* Vae submodel
|
||||
*/
|
||||
model?: string;
|
||||
vae?: VaeField;
|
||||
/**
|
||||
* Decode latents by overlaping tiles(less memory consumption)
|
||||
*/
|
||||
tiled?: boolean;
|
||||
};
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
import type { ConditioningField } from './ConditioningField';
|
||||
import type { ControlField } from './ControlField';
|
||||
import type { LatentsField } from './LatentsField';
|
||||
import type { UNetField } from './UNetField';
|
||||
|
||||
/**
|
||||
* Generates latents using latents as base image.
|
||||
@ -42,11 +43,11 @@ export type LatentsToLatentsInvocation = {
|
||||
/**
|
||||
* The scheduler to use
|
||||
*/
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
|
||||
/**
|
||||
* The model to use (currently ignored)
|
||||
* UNet submodel
|
||||
*/
|
||||
model?: string;
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* The control to use
|
||||
*/
|
||||
|
31
invokeai/frontend/web/src/services/api/models/LoraInfo.ts
Normal file
31
invokeai/frontend/web/src/services/api/models/LoraInfo.ts
Normal file
@ -0,0 +1,31 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { BaseModelType } from './BaseModelType';
|
||||
import type { ModelType } from './ModelType';
|
||||
import type { SubModelType } from './SubModelType';
|
||||
|
||||
export type LoraInfo = {
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* Base model
|
||||
*/
|
||||
base_model: BaseModelType;
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
model_type: ModelType;
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
submodel?: SubModelType;
|
||||
/**
|
||||
* Lora's weight which to use when apply to model
|
||||
*/
|
||||
weight: number;
|
||||
};
|
||||
|
@ -0,0 +1,38 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ClipField } from './ClipField';
|
||||
import type { UNetField } from './UNetField';
|
||||
|
||||
/**
|
||||
* Apply selected lora to unet and text_encoder.
|
||||
*/
|
||||
export type LoraLoaderInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'lora_loader';
|
||||
/**
|
||||
* Lora model name
|
||||
*/
|
||||
lora_name: string;
|
||||
/**
|
||||
* With what weight to apply lora
|
||||
*/
|
||||
weight?: number;
|
||||
/**
|
||||
* UNet model for applying lora
|
||||
*/
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* Clip model for applying lora
|
||||
*/
|
||||
clip?: ClipField;
|
||||
};
|
||||
|
@ -0,0 +1,22 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ClipField } from './ClipField';
|
||||
import type { UNetField } from './UNetField';
|
||||
|
||||
/**
|
||||
* Model loader output
|
||||
*/
|
||||
export type LoraLoaderOutput = {
|
||||
type?: 'lora_loader_output';
|
||||
/**
|
||||
* UNet submodel
|
||||
*/
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* Tokenizer and text_encoder submodels
|
||||
*/
|
||||
clip?: ClipField;
|
||||
};
|
||||
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type ModelError = 'not_found';
|
27
invokeai/frontend/web/src/services/api/models/ModelInfo.ts
Normal file
27
invokeai/frontend/web/src/services/api/models/ModelInfo.ts
Normal file
@ -0,0 +1,27 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { BaseModelType } from './BaseModelType';
|
||||
import type { ModelType } from './ModelType';
|
||||
import type { SubModelType } from './SubModelType';
|
||||
|
||||
export type ModelInfo = {
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
model_name: string;
|
||||
/**
|
||||
* Base model
|
||||
*/
|
||||
base_model: BaseModelType;
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
model_type: ModelType;
|
||||
/**
|
||||
* Info to load submodel
|
||||
*/
|
||||
submodel?: SubModelType;
|
||||
};
|
||||
|
@ -0,0 +1,27 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ClipField } from './ClipField';
|
||||
import type { UNetField } from './UNetField';
|
||||
import type { VaeField } from './VaeField';
|
||||
|
||||
/**
|
||||
* Model loader output
|
||||
*/
|
||||
export type ModelLoaderOutput = {
|
||||
type?: 'model_loader_output';
|
||||
/**
|
||||
* UNet submodel
|
||||
*/
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* Tokenizer and text_encoder submodels
|
||||
*/
|
||||
clip?: ClipField;
|
||||
/**
|
||||
* Vae submodel
|
||||
*/
|
||||
vae?: VaeField;
|
||||
};
|
||||
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type ModelType = 'pipeline' | 'vae' | 'lora' | 'controlnet' | 'embedding';
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type ModelVariantType = 'normal' | 'inpaint' | 'depth';
|
@ -2,10 +2,16 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { CkptModelInfo } from './CkptModelInfo';
|
||||
import type { DiffusersModelInfo } from './DiffusersModelInfo';
|
||||
import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
||||
import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
||||
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
||||
import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
||||
import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config';
|
||||
|
||||
export type ModelsList = {
|
||||
models: Record<string, (CkptModelInfo | DiffusersModelInfo)>;
|
||||
models: Record<string, Record<string, Record<string, (invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig | invokeai__backend__model_management__models__controlnet__ControlNetModel__Config | invokeai__backend__model_management__models__lora__LoRAModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig | invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config | invokeai__backend__model_management__models__vae__VaeModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig)>>>;
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,23 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Loading submodels of selected model.
|
||||
*/
|
||||
export type SD1ModelLoaderInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'sd1_model_loader';
|
||||
/**
|
||||
* Model to load
|
||||
*/
|
||||
model_name?: string;
|
||||
};
|
||||
|
@ -0,0 +1,23 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Loading submodels of selected model.
|
||||
*/
|
||||
export type SD2ModelLoaderInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'sd2_model_loader';
|
||||
/**
|
||||
* Model to load
|
||||
*/
|
||||
model_name?: string;
|
||||
};
|
||||
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type SchedulerPredictionType = 'epsilon' | 'v_prediction' | 'sample';
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type SubModelType = 'unet' | 'text_encoder' | 'tokenizer' | 'vae' | 'scheduler' | 'safety_checker';
|
@ -1,65 +0,0 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Generates an image using text2img.
|
||||
*/
|
||||
export type TextToImageInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'txt2img';
|
||||
/**
|
||||
* The prompt to generate an image from
|
||||
*/
|
||||
prompt?: string;
|
||||
/**
|
||||
* The seed to use (omit for random)
|
||||
*/
|
||||
seed?: number;
|
||||
/**
|
||||
* The number of steps to use to generate the image
|
||||
*/
|
||||
steps?: number;
|
||||
/**
|
||||
* The width of the resulting image
|
||||
*/
|
||||
width?: number;
|
||||
/**
|
||||
* The height of the resulting image
|
||||
*/
|
||||
height?: number;
|
||||
/**
|
||||
* The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
||||
*/
|
||||
cfg_scale?: number;
|
||||
/**
|
||||
* The scheduler to use
|
||||
*/
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
||||
/**
|
||||
* The model to use (currently ignored)
|
||||
*/
|
||||
model?: string;
|
||||
/**
|
||||
* Whether or not to produce progress images during generation
|
||||
*/
|
||||
progress_images?: boolean;
|
||||
/**
|
||||
* The control model to use
|
||||
*/
|
||||
control_model?: string;
|
||||
/**
|
||||
* The processed control image
|
||||
*/
|
||||
control_image?: ImageField;
|
||||
};
|
||||
|
@ -5,6 +5,7 @@
|
||||
import type { ConditioningField } from './ConditioningField';
|
||||
import type { ControlField } from './ControlField';
|
||||
import type { LatentsField } from './LatentsField';
|
||||
import type { UNetField } from './UNetField';
|
||||
|
||||
/**
|
||||
* Generates latents from conditionings.
|
||||
@ -42,11 +43,11 @@ export type TextToLatentsInvocation = {
|
||||
/**
|
||||
* The scheduler to use
|
||||
*/
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
|
||||
/**
|
||||
* The model to use (currently ignored)
|
||||
* UNet submodel
|
||||
*/
|
||||
model?: string;
|
||||
unet?: UNetField;
|
||||
/**
|
||||
* The control to use
|
||||
*/
|
||||
|
22
invokeai/frontend/web/src/services/api/models/UNetField.ts
Normal file
22
invokeai/frontend/web/src/services/api/models/UNetField.ts
Normal file
@ -0,0 +1,22 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { LoraInfo } from './LoraInfo';
|
||||
import type { ModelInfo } from './ModelInfo';
|
||||
|
||||
export type UNetField = {
|
||||
/**
|
||||
* Info to load unet submodel
|
||||
*/
|
||||
unet: ModelInfo;
|
||||
/**
|
||||
* Info to load scheduler submodel
|
||||
*/
|
||||
scheduler: ModelInfo;
|
||||
/**
|
||||
* Loras to apply on model loading
|
||||
*/
|
||||
loras: Array<LoraInfo>;
|
||||
};
|
||||
|
13
invokeai/frontend/web/src/services/api/models/VaeField.ts
Normal file
13
invokeai/frontend/web/src/services/api/models/VaeField.ts
Normal file
@ -0,0 +1,13 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelInfo } from './ModelInfo';
|
||||
|
||||
export type VaeField = {
|
||||
/**
|
||||
* Info to load vae submodel
|
||||
*/
|
||||
vae: ModelInfo;
|
||||
};
|
||||
|
@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: ('checkpoint' | 'diffusers');
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
};
|
||||
|
@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type invokeai__backend__model_management__models__lora__LoRAModel__Config = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: ('lycoris' | 'diffusers');
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
};
|
||||
|
@ -0,0 +1,18 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
import type { ModelVariantType } from './ModelVariantType';
|
||||
|
||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: 'checkpoint';
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
vae?: string;
|
||||
config?: string;
|
||||
variant: ModelVariantType;
|
||||
};
|
||||
|
@ -0,0 +1,17 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
import type { ModelVariantType } from './ModelVariantType';
|
||||
|
||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: 'diffusers';
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
vae?: string;
|
||||
variant: ModelVariantType;
|
||||
};
|
||||
|
@ -0,0 +1,21 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
import type { ModelVariantType } from './ModelVariantType';
|
||||
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||
|
||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: 'checkpoint';
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
vae?: string;
|
||||
config?: string;
|
||||
variant: ModelVariantType;
|
||||
prediction_type: SchedulerPredictionType;
|
||||
upcast_attention: boolean;
|
||||
};
|
||||
|
@ -0,0 +1,20 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
import type { ModelVariantType } from './ModelVariantType';
|
||||
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||
|
||||
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: 'diffusers';
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
vae?: string;
|
||||
variant: ModelVariantType;
|
||||
prediction_type: SchedulerPredictionType;
|
||||
upcast_attention: boolean;
|
||||
};
|
||||
|
@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: null;
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
};
|
||||
|
@ -0,0 +1,14 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type invokeai__backend__model_management__models__vae__VaeModel__Config = {
|
||||
path: string;
|
||||
description?: string;
|
||||
format: ('checkpoint' | 'diffusers');
|
||||
default?: boolean;
|
||||
error?: ModelError;
|
||||
};
|
||||
|
@ -1,8 +1,10 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
import type { BaseModelType } from '../models/BaseModelType';
|
||||
import type { CreateModelRequest } from '../models/CreateModelRequest';
|
||||
import type { ModelsList } from '../models/ModelsList';
|
||||
import type { ModelType } from '../models/ModelType';
|
||||
|
||||
import type { CancelablePromise } from '../core/CancelablePromise';
|
||||
import { OpenAPI } from '../core/OpenAPI';
|
||||
@ -16,10 +18,29 @@ export class ModelsService {
|
||||
* @returns ModelsList Successful Response
|
||||
* @throws ApiError
|
||||
*/
|
||||
public static listModels(): CancelablePromise<ModelsList> {
|
||||
public static listModels({
|
||||
baseModel,
|
||||
modelType,
|
||||
}: {
|
||||
/**
|
||||
* Base model
|
||||
*/
|
||||
baseModel?: BaseModelType,
|
||||
/**
|
||||
* The type of model to get
|
||||
*/
|
||||
modelType?: ModelType,
|
||||
}): CancelablePromise<ModelsList> {
|
||||
return __request(OpenAPI, {
|
||||
method: 'GET',
|
||||
url: '/api/v1/models/',
|
||||
query: {
|
||||
'base_model': baseModel,
|
||||
'model_type': modelType,
|
||||
},
|
||||
errors: {
|
||||
422: `Validation Error`,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,6 @@ import type { ImagePasteInvocation } from '../models/ImagePasteInvocation';
|
||||
import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation';
|
||||
import type { ImageResizeInvocation } from '../models/ImageResizeInvocation';
|
||||
import type { ImageScaleInvocation } from '../models/ImageScaleInvocation';
|
||||
import type { ImageToImageInvocation } from '../models/ImageToImageInvocation';
|
||||
import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation';
|
||||
import type { InfillColorInvocation } from '../models/InfillColorInvocation';
|
||||
import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation';
|
||||
@ -39,6 +38,7 @@ import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvoc
|
||||
import type { LineartAnimeImageProcessorInvocation } from '../models/LineartAnimeImageProcessorInvocation';
|
||||
import type { LineartImageProcessorInvocation } from '../models/LineartImageProcessorInvocation';
|
||||
import type { LoadImageInvocation } from '../models/LoadImageInvocation';
|
||||
import type { LoraLoaderInvocation } from '../models/LoraLoaderInvocation';
|
||||
import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation';
|
||||
import type { MediapipeFaceProcessorInvocation } from '../models/MediapipeFaceProcessorInvocation';
|
||||
import type { MidasDepthImageProcessorInvocation } from '../models/MidasDepthImageProcessorInvocation';
|
||||
@ -58,10 +58,11 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation';
|
||||
import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation';
|
||||
import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation';
|
||||
import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation';
|
||||
import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation';
|
||||
import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation';
|
||||
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
|
||||
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
|
||||
import type { SubtractInvocation } from '../models/SubtractInvocation';
|
||||
import type { TextToImageInvocation } from '../models/TextToImageInvocation';
|
||||
import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation';
|
||||
import type { UpscaleInvocation } from '../models/UpscaleInvocation';
|
||||
import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation';
|
||||
@ -174,7 +175,7 @@ export class SessionsService {
|
||||
* The id of the session
|
||||
*/
|
||||
sessionId: string,
|
||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
|
||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||
}): CancelablePromise<string> {
|
||||
return __request(OpenAPI, {
|
||||
method: 'POST',
|
||||
@ -211,7 +212,7 @@ export class SessionsService {
|
||||
* The path to the node in the graph
|
||||
*/
|
||||
nodePath: string,
|
||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
|
||||
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||
}): CancelablePromise<GraphExecutionState> {
|
||||
return __request(OpenAPI, {
|
||||
method: 'PUT',
|
||||
|
Loading…
Reference in New Issue
Block a user