mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into model-manager-ui-30
This commit is contained in:
commit
b0c4451324
@ -12,12 +12,19 @@ from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
|||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
|
|
||||||
from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
|
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
import re
|
||||||
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
from .model import UNetField, VaeField
|
||||||
|
from .compel import ConditioningField
|
||||||
|
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
DEFAULT_INFILL_METHOD = (
|
DEFAULT_INFILL_METHOD = (
|
||||||
@ -25,114 +32,48 @@ DEFAULT_INFILL_METHOD = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDImageInvocation(BaseModel):
|
from .latent import get_scheduler
|
||||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
|
||||||
|
|
||||||
# Schema customisation
|
class OldModelContext(ContextDecorator):
|
||||||
class Config(InvocationConfig):
|
model: StableDiffusionGeneratorPipeline
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
def __init__(self, model):
|
||||||
"tags": ["stable-diffusion", "image"],
|
self.model = model
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
def __enter__(self):
|
||||||
},
|
return self.model
|
||||||
},
|
|
||||||
}
|
def __exit__(self, *exc):
|
||||||
|
return False
|
||||||
|
|
||||||
|
class OldModelInfo:
|
||||||
|
name: str
|
||||||
|
hash: str
|
||||||
|
context: OldModelContext
|
||||||
|
|
||||||
|
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||||
|
self.name = name
|
||||||
|
self.hash = hash
|
||||||
|
self.context = OldModelContext(
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
class InpaintInvocation(BaseInvocation):
|
||||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
"""Generates an image using inpaint."""
|
||||||
"""Generates an image using text2img."""
|
|
||||||
|
|
||||||
type: Literal["txt2img"] = "txt2img"
|
type: Literal["inpaint"] = "inpaint"
|
||||||
|
|
||||||
# Inputs
|
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
# fmt: off
|
|
||||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
|
||||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
unet: UNetField = Field(default=None, description="UNet model")
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
vae: VaeField = Field(default=None, description="Vae model")
|
||||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
|
||||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
|
||||||
def dispatch_progress(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
) -> None:
|
|
||||||
stable_diffusion_step_callback(
|
|
||||||
context=context,
|
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
# Handle invalid model parameter
|
|
||||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
|
||||||
|
|
||||||
# loading controlnet image (currently requires pre-processed image)
|
|
||||||
control_image = (
|
|
||||||
None if self.control_image is None
|
|
||||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
|
||||||
)
|
|
||||||
# loading controlnet model
|
|
||||||
if (self.control_model is None or self.control_model==''):
|
|
||||||
control_model = None
|
|
||||||
else:
|
|
||||||
# FIXME: change this to dropdown menu?
|
|
||||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
|
||||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
|
||||||
torch_dtype=torch.float16).to("cuda")
|
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
txt2img = Txt2Img(model, control_model=control_model)
|
|
||||||
outputs = txt2img.generate(
|
|
||||||
prompt=self.prompt,
|
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
|
||||||
control_image=control_image,
|
|
||||||
**self.dict(
|
|
||||||
exclude={"prompt", "control_image" }
|
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
|
||||||
# each time it is called. We only need the first one.
|
|
||||||
generate_output = next(outputs)
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=generate_output.image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
node_id=self.id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageToImageInvocation(TextToImageInvocation):
|
|
||||||
"""Generates an image using img2img."""
|
|
||||||
|
|
||||||
type: Literal["img2img"] = "img2img"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image")
|
image: Union[ImageField, None] = Field(description="The input image")
|
||||||
@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_progress(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
) -> None:
|
|
||||||
stable_diffusion_step_callback(
|
|
||||||
context=context,
|
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = (
|
|
||||||
None
|
|
||||||
if self.image is None
|
|
||||||
else context.services.images.get_pil_image(self.image.image_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.fit:
|
|
||||||
image = image.resize((self.width, self.height))
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
|
||||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
outputs = Img2Img(model).generate(
|
|
||||||
prompt=self.prompt,
|
|
||||||
init_image=image,
|
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
|
||||||
**self.dict(
|
|
||||||
exclude={"prompt", "image", "mask"}
|
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
|
||||||
|
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
|
||||||
# each time it is called. We only need the first one.
|
|
||||||
generator_output = next(outputs)
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=generator_output.image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
node_id=self.id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
|
||||||
"""Generates an image using inpaint."""
|
|
||||||
|
|
||||||
type: Literal["inpaint"] = "inpaint"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
mask: Union[ImageField, None] = Field(description="The mask")
|
mask: Union[ImageField, None] = Field(description="The mask")
|
||||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||||
@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
description="The amount by which to replace masked areas with latent noise",
|
description="The amount by which to replace masked areas with latent noise",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["stable-diffusion", "image"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@ -265,6 +148,49 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_conditioning(self, context):
|
||||||
|
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||||
|
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
|
|
||||||
|
return (uc, c, extra_conditioning_info)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def load_model_old_way(self, context, scheduler):
|
||||||
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||||
|
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
|
|
||||||
|
#unet = unet_info.context.model
|
||||||
|
#vae = vae_info.context.model
|
||||||
|
|
||||||
|
with ExitStack() as stack:
|
||||||
|
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
|
|
||||||
|
with vae_info as vae,\
|
||||||
|
unet_info as unet,\
|
||||||
|
ModelPatcher.apply_lora_unet(unet, loras):
|
||||||
|
|
||||||
|
device = context.services.model_manager.mgr.cache.execution_device
|
||||||
|
dtype = context.services.model_manager.mgr.cache.precision
|
||||||
|
|
||||||
|
pipeline = StableDiffusionGeneratorPipeline(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=None,
|
||||||
|
tokenizer=None,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
precision="float16" if dtype == torch.float16 else "float32",
|
||||||
|
execution_device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield OldModelInfo(
|
||||||
|
name=self.unet.unet.model_name,
|
||||||
|
hash="<NO-HASH>",
|
||||||
|
model=pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
@ -277,25 +203,31 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
else context.services.images.get_pil_image(self.mask.image_name)
|
else context.services.images.get_pil_image(self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
|
||||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
context.graph_execution_state_id
|
context.graph_execution_state_id
|
||||||
)
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
outputs = Inpaint(model).generate(
|
conditioning = self.get_conditioning(context)
|
||||||
prompt=self.prompt,
|
scheduler = get_scheduler(
|
||||||
init_image=image,
|
context=context,
|
||||||
mask_image=mask,
|
scheduler_info=self.unet.scheduler,
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
scheduler_name=self.scheduler,
|
||||||
**self.dict(
|
|
||||||
exclude={"prompt", "image", "mask"}
|
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with self.load_model_old_way(context, scheduler) as model:
|
||||||
|
outputs = Inpaint(model).generate(
|
||||||
|
conditioning=conditioning,
|
||||||
|
scheduler=scheduler,
|
||||||
|
init_image=image,
|
||||||
|
mask_image=mask,
|
||||||
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
|
**self.dict(
|
||||||
|
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||||
|
), # Shorthand for passing all of the parameters above manually
|
||||||
|
)
|
||||||
|
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
generator_output = next(outputs)
|
generator_output = next(outputs)
|
||||||
|
@ -5,7 +5,6 @@ from .generator import (
|
|||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
Txt2Img,
|
|
||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,6 @@ from .base import (
|
|||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
Txt2Img,
|
|
||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint,
|
Inpaint,
|
||||||
Generator,
|
Generator,
|
||||||
|
@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
|
|||||||
from ..image_util import configure_model_padding
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
|
||||||
@ -81,13 +80,15 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
self.params=params
|
self.params=params
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def generate(self,
|
def generate(
|
||||||
prompt: str='',
|
self,
|
||||||
callback: Optional[Callable]=None,
|
conditioning: tuple,
|
||||||
step_callback: Optional[Callable]=None,
|
scheduler,
|
||||||
iterations: int=1,
|
callback: Optional[Callable]=None,
|
||||||
**keyword_args,
|
step_callback: Optional[Callable]=None,
|
||||||
)->Iterator[InvokeAIGeneratorOutput]:
|
iterations: int=1,
|
||||||
|
**keyword_args,
|
||||||
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
'''
|
'''
|
||||||
Return an iterator across the indicated number of generations.
|
Return an iterator across the indicated number of generations.
|
||||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||||
@ -116,11 +117,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
model_name = model_info.name
|
model_name = model_info.name
|
||||||
model_hash = model_info.hash
|
model_hash = model_info.hash
|
||||||
with model_info.context as model:
|
with model_info.context as model:
|
||||||
scheduler: Scheduler = self.get_scheduler(
|
|
||||||
model=model,
|
|
||||||
scheduler_name=generator_args.get('scheduler')
|
|
||||||
)
|
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
|
||||||
gen_class = self._generator_class()
|
gen_class = self._generator_class()
|
||||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
@ -143,12 +139,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||||
for i in iteration_count:
|
for i in iteration_count:
|
||||||
results = generator.generate(prompt,
|
results = generator.generate(
|
||||||
conditioning=(uc, c, extra_conditioning_info),
|
conditioning=conditioning,
|
||||||
step_callback=step_callback,
|
step_callback=step_callback,
|
||||||
sampler=scheduler,
|
sampler=scheduler,
|
||||||
**generator_args,
|
**generator_args,
|
||||||
)
|
)
|
||||||
output = InvokeAIGeneratorOutput(
|
output = InvokeAIGeneratorOutput(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
seed=results[0][1],
|
seed=results[0][1],
|
||||||
@ -170,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
return generator_class(model, self.params.precision)
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
|
||||||
|
|
||||||
scheduler_config = model.scheduler.config
|
|
||||||
if "_backup" in scheduler_config:
|
|
||||||
scheduler_config = scheduler_config["_backup"]
|
|
||||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
|
||||||
|
|
||||||
# hack copied over from generate.py
|
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
|
||||||
return scheduler
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_class(cls)->Type[Generator]:
|
def _generator_class(cls)->Type[Generator]:
|
||||||
'''
|
'''
|
||||||
@ -193,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
'''
|
'''
|
||||||
return Generator
|
return Generator
|
||||||
|
|
||||||
# ------------------------------------
|
|
||||||
class Txt2Img(InvokeAIGenerator):
|
|
||||||
@classmethod
|
|
||||||
def _generator_class(cls):
|
|
||||||
from .txt2img import Txt2Img
|
|
||||||
return Txt2Img
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
@ -253,24 +228,6 @@ class Inpaint(Img2Img):
|
|||||||
from .inpaint import Inpaint
|
from .inpaint import Inpaint
|
||||||
return Inpaint
|
return Inpaint
|
||||||
|
|
||||||
# ------------------------------------
|
|
||||||
class Embiggen(Txt2Img):
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
embiggen: list=None,
|
|
||||||
embiggen_tiles: list = None,
|
|
||||||
strength: float=0.75,
|
|
||||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
|
||||||
return super().generate(embiggen=embiggen,
|
|
||||||
embiggen_tiles=embiggen_tiles,
|
|
||||||
strength=strength,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _generator_class(cls):
|
|
||||||
from .embiggen import Embiggen
|
|
||||||
return Embiggen
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
latent_channels: int
|
latent_channels: int
|
||||||
@ -281,7 +238,7 @@ class Generator:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
self.latent_channels = model.channels
|
self.latent_channels = model.unet.config.in_channels
|
||||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
self.perlin = 0.0
|
self.perlin = 0.0
|
||||||
@ -292,7 +249,7 @@ class Generator:
|
|||||||
self.free_gpu_mem = None
|
self.free_gpu_mem = None
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self, prompt, **kwargs):
|
def get_make_image(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and the initial image
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
Return value depends on the seed at the time you call it
|
Return value depends on the seed at the time you call it
|
||||||
@ -308,7 +265,6 @@ class Generator:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt,
|
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sampler,
|
sampler,
|
||||||
@ -333,7 +289,6 @@ class Generator:
|
|||||||
saver.get_stacked_maps_image()
|
saver.get_stacked_maps_image()
|
||||||
)
|
)
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
init_image=init_image,
|
init_image=init_image,
|
||||||
width=width,
|
width=width,
|
||||||
|
@ -1,559 +0,0 @@
|
|||||||
"""
|
|
||||||
invokeai.backend.generator.embiggen descends from .generator
|
|
||||||
and generates with .generator.img2img
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from .base import Generator
|
|
||||||
from .img2img import Img2Img
|
|
||||||
|
|
||||||
class Embiggen(Generator):
|
|
||||||
def __init__(self, model, precision):
|
|
||||||
super().__init__(model, precision)
|
|
||||||
self.init_latent = None
|
|
||||||
|
|
||||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
iterations=1,
|
|
||||||
seed=None,
|
|
||||||
image_callback=None,
|
|
||||||
step_callback=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
|
|
||||||
results = []
|
|
||||||
seed = seed if seed else self.new_seed()
|
|
||||||
|
|
||||||
# Noise will be generated by the Img2Img generator when called
|
|
||||||
for _ in trange(iterations, desc="Generating"):
|
|
||||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
|
||||||
image = make_image()
|
|
||||||
results.append([image, seed])
|
|
||||||
if image_callback is not None:
|
|
||||||
image_callback(image, seed, prompt_in=prompt)
|
|
||||||
seed = self.new_seed()
|
|
||||||
return results
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_make_image(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
steps,
|
|
||||||
cfg_scale,
|
|
||||||
ddim_eta,
|
|
||||||
conditioning,
|
|
||||||
init_img,
|
|
||||||
strength,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
embiggen,
|
|
||||||
embiggen_tiles,
|
|
||||||
step_callback=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
|
||||||
Return value depends on the seed at the time you call it
|
|
||||||
"""
|
|
||||||
assert (
|
|
||||||
not sampler.uses_inpainting_model()
|
|
||||||
), "--embiggen is not supported by inpainting models"
|
|
||||||
|
|
||||||
# Construct embiggen arg array, and sanity check arguments
|
|
||||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
|
||||||
embiggen = [1.0] # If not specified, assume no scaling
|
|
||||||
elif embiggen[0] < 0:
|
|
||||||
embiggen[0] = 1.0
|
|
||||||
logger.warning(
|
|
||||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
|
||||||
)
|
|
||||||
if len(embiggen) < 2:
|
|
||||||
embiggen.append(0.75)
|
|
||||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
|
||||||
embiggen[1] = 0.75
|
|
||||||
logger.warning(
|
|
||||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
|
||||||
)
|
|
||||||
if len(embiggen) < 3:
|
|
||||||
embiggen.append(0.25)
|
|
||||||
elif embiggen[2] < 0:
|
|
||||||
embiggen[2] = 0.25
|
|
||||||
logger.warning(
|
|
||||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
|
||||||
# and then sort them, because... people.
|
|
||||||
if embiggen_tiles:
|
|
||||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
|
||||||
embiggen_tiles.sort()
|
|
||||||
|
|
||||||
if strength >= 0.5:
|
|
||||||
logger.warning(
|
|
||||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prep img2img generator, since we wrap over it
|
|
||||||
gen_img2img = Img2Img(self.model, self.precision)
|
|
||||||
|
|
||||||
# Open original init image (not a tensor) to manipulate
|
|
||||||
initsuperimage = Image.open(init_img)
|
|
||||||
|
|
||||||
with Image.open(init_img) as img:
|
|
||||||
initsuperimage = img.convert("RGB")
|
|
||||||
|
|
||||||
# Size of the target super init image in pixels
|
|
||||||
initsuperwidth, initsuperheight = initsuperimage.size
|
|
||||||
|
|
||||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
|
||||||
if embiggen[0] != 1.0:
|
|
||||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
|
||||||
initsuperheight = round(initsuperheight * embiggen[0])
|
|
||||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
|
||||||
from ..restoration.realesrgan import ESRGAN
|
|
||||||
|
|
||||||
esrgan = ESRGAN()
|
|
||||||
logger.info(
|
|
||||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
|
||||||
)
|
|
||||||
if embiggen[0] > 2:
|
|
||||||
initsuperimage = esrgan.process(
|
|
||||||
initsuperimage,
|
|
||||||
embiggen[1], # upscale strength
|
|
||||||
self.seed,
|
|
||||||
4, # upscale scale
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
initsuperimage = esrgan.process(
|
|
||||||
initsuperimage,
|
|
||||||
embiggen[1], # upscale strength
|
|
||||||
self.seed,
|
|
||||||
2, # upscale scale
|
|
||||||
)
|
|
||||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
|
||||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
|
||||||
# Resize to target scaling factor resolution
|
|
||||||
initsuperimage = initsuperimage.resize(
|
|
||||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use width and height as tile widths and height
|
|
||||||
# Determine buffer size in pixels
|
|
||||||
if embiggen[2] < 1:
|
|
||||||
if embiggen[2] < 0:
|
|
||||||
embiggen[2] = 0
|
|
||||||
overlap_size_x = round(embiggen[2] * width)
|
|
||||||
overlap_size_y = round(embiggen[2] * height)
|
|
||||||
else:
|
|
||||||
overlap_size_x = round(embiggen[2])
|
|
||||||
overlap_size_y = round(embiggen[2])
|
|
||||||
|
|
||||||
# With overall image width and height known, determine how many tiles we need
|
|
||||||
def ceildiv(a, b):
|
|
||||||
return -1 * (-a // b)
|
|
||||||
|
|
||||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
|
||||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
|
||||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
|
||||||
emb_tiles_x = 1
|
|
||||||
emb_tiles_y = 1
|
|
||||||
if (initsuperwidth - width) > 0:
|
|
||||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
|
||||||
if (initsuperheight - height) > 0:
|
|
||||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
|
||||||
# Sanity
|
|
||||||
assert (
|
|
||||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
|
||||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
|
||||||
|
|
||||||
# Prep alpha layers --------------
|
|
||||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
|
||||||
# agradientL is Left-side transparent
|
|
||||||
agradientL = (
|
|
||||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
|
||||||
)
|
|
||||||
# agradientT is Top-side transparent
|
|
||||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
|
||||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
|
||||||
agradientC = Image.new("L", (256, 256))
|
|
||||||
for y in range(256):
|
|
||||||
for x in range(256):
|
|
||||||
# Find distance to lower right corner (numpy takes arrays)
|
|
||||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
|
||||||
# Clamp values to max 255
|
|
||||||
if distanceToLR > 255:
|
|
||||||
distanceToLR = 255
|
|
||||||
# Place the pixel as invert of distance
|
|
||||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
|
||||||
|
|
||||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
|
||||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
|
||||||
agradientAsymC = Image.new("L", (256, 256))
|
|
||||||
for y in range(256):
|
|
||||||
for x in range(256):
|
|
||||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
|
||||||
# Clamp values
|
|
||||||
value = max(0, value)
|
|
||||||
value = min(255, value)
|
|
||||||
agradientAsymC.putpixel((x, y), value)
|
|
||||||
|
|
||||||
# Create alpha layers default fully white
|
|
||||||
alphaLayerL = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerT = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
|
||||||
# Paste gradients into alpha layers
|
|
||||||
alphaLayerL.paste(agradientL, (0, 0))
|
|
||||||
alphaLayerT.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
|
||||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
|
||||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
|
||||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
|
||||||
alphaLayerTaC = alphaLayerT.copy()
|
|
||||||
alphaLayerTaC.paste(
|
|
||||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, 0),
|
|
||||||
)
|
|
||||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
|
||||||
alphaLayerLTaC.paste(
|
|
||||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
if embiggen_tiles:
|
|
||||||
# Individual unconnected sides
|
|
||||||
alphaLayerR = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
|
||||||
alphaLayerB = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
|
||||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerTB.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
|
||||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerLR.paste(agradientL, (0, 0))
|
|
||||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
|
||||||
|
|
||||||
# Sides and corner Layers
|
|
||||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
|
||||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
|
||||||
alphaLayerRBC.paste(
|
|
||||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, height - overlap_size_y),
|
|
||||||
)
|
|
||||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
|
||||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
|
||||||
alphaLayerLBC.paste(
|
|
||||||
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(0, height - overlap_size_y),
|
|
||||||
)
|
|
||||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
|
||||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerRTC.paste(
|
|
||||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# All but X layers
|
|
||||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
|
||||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
|
||||||
alphaLayerABT.paste(
|
|
||||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, height - overlap_size_y),
|
|
||||||
)
|
|
||||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
|
||||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
|
||||||
alphaLayerABL.paste(
|
|
||||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, height - overlap_size_y),
|
|
||||||
)
|
|
||||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
|
||||||
alphaLayerABR.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerABR.paste(
|
|
||||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
|
||||||
)
|
|
||||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
|
||||||
alphaLayerABB.paste(agradientL, (0, 0))
|
|
||||||
alphaLayerABB.paste(
|
|
||||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# All-around layer
|
|
||||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
|
||||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
|
||||||
alphaLayerAA.paste(agradientT, (0, 0))
|
|
||||||
alphaLayerAA.paste(
|
|
||||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
|
||||||
)
|
|
||||||
alphaLayerAA.paste(
|
|
||||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
|
||||||
(width - overlap_size_x, 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up temporary gradients
|
|
||||||
del agradientL
|
|
||||||
del agradientT
|
|
||||||
del agradientC
|
|
||||||
|
|
||||||
def make_image():
|
|
||||||
# Make main tiles -------------------------------------------------
|
|
||||||
if embiggen_tiles:
|
|
||||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
|
||||||
)
|
|
||||||
|
|
||||||
emb_tile_store = []
|
|
||||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
|
||||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
|
||||||
# instead track and iterate a local seed we pass to Img2Img
|
|
||||||
seed = self.seed
|
|
||||||
seedintlimit = (
|
|
||||||
np.iinfo(np.uint32).max - 1
|
|
||||||
) # only retreive this one from numpy
|
|
||||||
|
|
||||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
|
||||||
# Don't iterate on first tile
|
|
||||||
if tile != 0:
|
|
||||||
if seed < seedintlimit:
|
|
||||||
seed += 1
|
|
||||||
else:
|
|
||||||
seed = 0
|
|
||||||
|
|
||||||
# Determine if this is a re-run and replace
|
|
||||||
if embiggen_tiles and not tile in embiggen_tiles:
|
|
||||||
continue
|
|
||||||
# Get row and column entries
|
|
||||||
emb_row_i = tile // emb_tiles_x
|
|
||||||
emb_column_i = tile % emb_tiles_x
|
|
||||||
# Determine bounds to cut up the init image
|
|
||||||
# Determine upper-left point
|
|
||||||
if emb_column_i + 1 == emb_tiles_x:
|
|
||||||
left = initsuperwidth - width
|
|
||||||
else:
|
|
||||||
left = round(emb_column_i * (width - overlap_size_x))
|
|
||||||
if emb_row_i + 1 == emb_tiles_y:
|
|
||||||
top = initsuperheight - height
|
|
||||||
else:
|
|
||||||
top = round(emb_row_i * (height - overlap_size_y))
|
|
||||||
right = left + width
|
|
||||||
bottom = top + height
|
|
||||||
|
|
||||||
# Cropped image of above dimension (does not modify the original)
|
|
||||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
|
||||||
# DEBUG:
|
|
||||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
|
||||||
# newinitimage.save(newinitimagepath)
|
|
||||||
|
|
||||||
if embiggen_tiles:
|
|
||||||
logger.debug(
|
|
||||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
|
||||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
|
||||||
newinitimage = torch.from_numpy(newinitimage)
|
|
||||||
newinitimage = 2.0 * newinitimage - 1.0
|
|
||||||
newinitimage = newinitimage.to(self.model.device)
|
|
||||||
clear_cuda_cache = (
|
|
||||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
|
||||||
)
|
|
||||||
|
|
||||||
tile_results = gen_img2img.generate(
|
|
||||||
prompt,
|
|
||||||
iterations=1,
|
|
||||||
seed=seed,
|
|
||||||
sampler=sampler,
|
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
conditioning=conditioning,
|
|
||||||
ddim_eta=ddim_eta,
|
|
||||||
image_callback=None, # called only after the final image is generated
|
|
||||||
step_callback=step_callback, # called after each intermediate image is generated
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
init_image=newinitimage, # notice that init_image is different from init_img
|
|
||||||
mask_image=None,
|
|
||||||
strength=strength,
|
|
||||||
clear_cuda_cache=clear_cuda_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
emb_tile_store.append(tile_results[0][0])
|
|
||||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
|
||||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
|
||||||
del newinitimage
|
|
||||||
|
|
||||||
# Sanity check we have them all
|
|
||||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
|
|
||||||
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
|
|
||||||
):
|
|
||||||
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
|
|
||||||
if embiggen_tiles:
|
|
||||||
outputsuperimage.alpha_composite(
|
|
||||||
initsuperimage.convert("RGBA"), (0, 0)
|
|
||||||
)
|
|
||||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
|
||||||
if embiggen_tiles:
|
|
||||||
if tile in embiggen_tiles:
|
|
||||||
intileimage = emb_tile_store.pop(0)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
intileimage = emb_tile_store[tile]
|
|
||||||
intileimage = intileimage.convert("RGBA")
|
|
||||||
# Get row and column entries
|
|
||||||
emb_row_i = tile // emb_tiles_x
|
|
||||||
emb_column_i = tile % emb_tiles_x
|
|
||||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
|
||||||
left = 0
|
|
||||||
top = 0
|
|
||||||
else:
|
|
||||||
# Determine upper-left point
|
|
||||||
if emb_column_i + 1 == emb_tiles_x:
|
|
||||||
left = initsuperwidth - width
|
|
||||||
else:
|
|
||||||
left = round(emb_column_i * (width - overlap_size_x))
|
|
||||||
if emb_row_i + 1 == emb_tiles_y:
|
|
||||||
top = initsuperheight - height
|
|
||||||
else:
|
|
||||||
top = round(emb_row_i * (height - overlap_size_y))
|
|
||||||
# Handle gradients for various conditions
|
|
||||||
# Handle emb_rerun case
|
|
||||||
if embiggen_tiles:
|
|
||||||
# top of image
|
|
||||||
if emb_row_i == 0:
|
|
||||||
if emb_column_i == 0:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) not in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerB)
|
|
||||||
# Otherwise do nothing on this tile
|
|
||||||
elif (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down only
|
|
||||||
intileimage.putalpha(alphaLayerR)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerRBC)
|
|
||||||
elif emb_column_i == emb_tiles_x - 1:
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerL)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerLBC)
|
|
||||||
else:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerL)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerLBC)
|
|
||||||
elif (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down only
|
|
||||||
intileimage.putalpha(alphaLayerLR)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerABT)
|
|
||||||
# bottom of image
|
|
||||||
elif emb_row_i == emb_tiles_y - 1:
|
|
||||||
if emb_column_i == 0:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
intileimage.putalpha(alphaLayerTaC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerRTC)
|
|
||||||
elif emb_column_i == emb_tiles_x - 1:
|
|
||||||
# No tiles to look ahead to
|
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
|
||||||
else:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
intileimage.putalpha(alphaLayerLTaC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerABB)
|
|
||||||
# vertical middle of image
|
|
||||||
else:
|
|
||||||
if emb_column_i == 0:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerTaC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerTB)
|
|
||||||
elif (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down only
|
|
||||||
intileimage.putalpha(alphaLayerRTC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerABL)
|
|
||||||
elif emb_column_i == emb_tiles_x - 1:
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerABR)
|
|
||||||
else:
|
|
||||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
|
||||||
if (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down
|
|
||||||
intileimage.putalpha(alphaLayerLTaC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerABR)
|
|
||||||
elif (
|
|
||||||
tile + emb_tiles_x
|
|
||||||
) in embiggen_tiles: # Look-ahead down only
|
|
||||||
intileimage.putalpha(alphaLayerABB)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerAA)
|
|
||||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
|
||||||
else:
|
|
||||||
if emb_row_i == 0 and emb_column_i >= 1:
|
|
||||||
intileimage.putalpha(alphaLayerL)
|
|
||||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
|
||||||
if (
|
|
||||||
emb_column_i + 1 == emb_tiles_x
|
|
||||||
): # If we don't have anything that can be placed to the right
|
|
||||||
intileimage.putalpha(alphaLayerT)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerTaC)
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
emb_column_i + 1 == emb_tiles_x
|
|
||||||
): # If we don't have anything that can be placed to the right
|
|
||||||
intileimage.putalpha(alphaLayerLTC)
|
|
||||||
else:
|
|
||||||
intileimage.putalpha(alphaLayerLTaC)
|
|
||||||
# Layer tile onto final image
|
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
# after internal loops and patching up return Embiggen image
|
|
||||||
return outputsuperimage
|
|
||||||
|
|
||||||
# end of function declaration
|
|
||||||
return make_image
|
|
@ -22,7 +22,6 @@ class Img2Img(Generator):
|
|||||||
|
|
||||||
def get_make_image(
|
def get_make_image(
|
||||||
self,
|
self,
|
||||||
prompt,
|
|
||||||
sampler,
|
sampler,
|
||||||
steps,
|
steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
@ -161,9 +161,7 @@ class Inpaint(Img2Img):
|
|||||||
im: Image.Image,
|
im: Image.Image,
|
||||||
seam_size: int,
|
seam_size: int,
|
||||||
seam_blur: int,
|
seam_blur: int,
|
||||||
prompt,
|
|
||||||
seed,
|
seed,
|
||||||
sampler,
|
|
||||||
steps,
|
steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
@ -177,8 +175,6 @@ class Inpaint(Img2Img):
|
|||||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||||
|
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
steps,
|
steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
@ -203,8 +199,6 @@ class Inpaint(Img2Img):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(
|
def get_make_image(
|
||||||
self,
|
self,
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
steps,
|
steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
@ -306,7 +300,6 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
pipeline.scheduler = sampler
|
|
||||||
|
|
||||||
# todo: support cross-attention control
|
# todo: support cross-attention control
|
||||||
uc, c, _ = conditioning
|
uc, c, _ = conditioning
|
||||||
@ -345,9 +338,7 @@ class Inpaint(Img2Img):
|
|||||||
result,
|
result,
|
||||||
seam_size,
|
seam_size,
|
||||||
seam_blur,
|
seam_blur,
|
||||||
prompt,
|
|
||||||
seed,
|
seed,
|
||||||
sampler,
|
|
||||||
seam_steps,
|
seam_steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
@ -360,8 +351,6 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
# Restore original settings
|
# Restore original settings
|
||||||
self.get_make_image(
|
self.get_make_image(
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
steps,
|
steps,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
|
@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
|
||||||
"""
|
|
||||||
import PIL.Image
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
||||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
|
||||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
|
||||||
|
|
||||||
from ..stable_diffusion import (
|
|
||||||
ConditioningData,
|
|
||||||
PostprocessingSettings,
|
|
||||||
StableDiffusionGeneratorPipeline,
|
|
||||||
)
|
|
||||||
from .base import Generator
|
|
||||||
|
|
||||||
|
|
||||||
class Txt2Img(Generator):
|
|
||||||
def __init__(self, model, precision,
|
|
||||||
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
|
||||||
**kwargs):
|
|
||||||
self.control_model = control_model
|
|
||||||
if isinstance(self.control_model, list):
|
|
||||||
self.control_model = MultiControlNetModel(self.control_model)
|
|
||||||
super().__init__(model, precision, **kwargs)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def get_make_image(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
steps,
|
|
||||||
cfg_scale,
|
|
||||||
ddim_eta,
|
|
||||||
conditioning,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
step_callback=None,
|
|
||||||
threshold=0.0,
|
|
||||||
warmup=0.2,
|
|
||||||
perlin=0.0,
|
|
||||||
h_symmetry_time_pct=None,
|
|
||||||
v_symmetry_time_pct=None,
|
|
||||||
attention_maps_callback=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Returns a function returning an image derived from the prompt and the initial image
|
|
||||||
Return value depends on the seed at the time you call it
|
|
||||||
kwargs are 'width' and 'height'
|
|
||||||
"""
|
|
||||||
self.perlin = perlin
|
|
||||||
control_image = kwargs.get("control_image", None)
|
|
||||||
do_classifier_free_guidance = cfg_scale > 1.0
|
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
|
||||||
pipeline.control_model = self.control_model
|
|
||||||
pipeline.scheduler = sampler
|
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
|
||||||
conditioning_data = ConditioningData(
|
|
||||||
uc,
|
|
||||||
c,
|
|
||||||
cfg_scale,
|
|
||||||
extra_conditioning_info,
|
|
||||||
postprocessing_settings=PostprocessingSettings(
|
|
||||||
threshold=threshold,
|
|
||||||
warmup=warmup,
|
|
||||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
|
||||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
|
||||||
),
|
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
|
||||||
|
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
|
||||||
# and add in batch_size, num_images_per_prompt?
|
|
||||||
if control_image is not None:
|
|
||||||
if isinstance(self.control_model, ControlNetModel):
|
|
||||||
control_image = pipeline.prepare_control_image(
|
|
||||||
image=control_image,
|
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
# batch_size=batch_size * num_images_per_prompt,
|
|
||||||
# num_images_per_prompt=num_images_per_prompt,
|
|
||||||
device=self.control_model.device,
|
|
||||||
dtype=self.control_model.dtype,
|
|
||||||
)
|
|
||||||
elif isinstance(self.control_model, MultiControlNetModel):
|
|
||||||
images = []
|
|
||||||
for image_ in control_image:
|
|
||||||
image_ = pipeline.prepare_control_image(
|
|
||||||
image=image_,
|
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
# batch_size=batch_size * num_images_per_prompt,
|
|
||||||
# num_images_per_prompt=num_images_per_prompt,
|
|
||||||
device=self.control_model.device,
|
|
||||||
dtype=self.control_model.dtype,
|
|
||||||
)
|
|
||||||
images.append(image_)
|
|
||||||
control_image = images
|
|
||||||
kwargs["control_image"] = control_image
|
|
||||||
|
|
||||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
|
||||||
pipeline_output = pipeline.image_from_embeddings(
|
|
||||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
|
||||||
noise=x_T,
|
|
||||||
num_inference_steps=steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
callback=step_callback,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
pipeline_output.attention_map_saver is not None
|
|
||||||
and attention_maps_callback is not None
|
|
||||||
):
|
|
||||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
|
||||||
|
|
||||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
|
||||||
|
|
||||||
return make_image
|
|
@ -1,209 +0,0 @@
|
|||||||
"""
|
|
||||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
|
||||||
|
|
||||||
from ..stable_diffusion import PostprocessingSettings
|
|
||||||
from .base import Generator
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Txt2Img2Img(Generator):
|
|
||||||
def __init__(self, model, precision):
|
|
||||||
super().__init__(model, precision)
|
|
||||||
self.init_latent = None # for get_noise()
|
|
||||||
|
|
||||||
def get_make_image(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
sampler,
|
|
||||||
steps: int,
|
|
||||||
cfg_scale: float,
|
|
||||||
ddim_eta,
|
|
||||||
conditioning,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
strength: float,
|
|
||||||
step_callback: Optional[Callable] = None,
|
|
||||||
threshold=0.0,
|
|
||||||
warmup=0.2,
|
|
||||||
perlin=0.0,
|
|
||||||
h_symmetry_time_pct=None,
|
|
||||||
v_symmetry_time_pct=None,
|
|
||||||
attention_maps_callback=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Returns a function returning an image derived from the prompt and the initial image
|
|
||||||
Return value depends on the seed at the time you call it
|
|
||||||
kwargs are 'width' and 'height'
|
|
||||||
"""
|
|
||||||
self.perlin = perlin
|
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
|
||||||
pipeline.scheduler = sampler
|
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
|
||||||
conditioning_data = ConditioningData(
|
|
||||||
uc,
|
|
||||||
c,
|
|
||||||
cfg_scale,
|
|
||||||
extra_conditioning_info,
|
|
||||||
postprocessing_settings=PostprocessingSettings(
|
|
||||||
threshold=threshold,
|
|
||||||
warmup=0.2,
|
|
||||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
|
||||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
|
||||||
),
|
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
|
||||||
|
|
||||||
def make_image(x_T: torch.Tensor, _: int):
|
|
||||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
|
||||||
latents=torch.zeros_like(x_T),
|
|
||||||
num_inference_steps=steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
noise=x_T,
|
|
||||||
callback=step_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get our initial generation width and height directly from the latent output so
|
|
||||||
# the message below is accurate.
|
|
||||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
|
||||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
|
||||||
logger.info(
|
|
||||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
|
||||||
)
|
|
||||||
|
|
||||||
# resizing
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
|
||||||
first_pass_latent_output,
|
|
||||||
size=(
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor,
|
|
||||||
),
|
|
||||||
mode="bilinear",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Free up memory from the last generation.
|
|
||||||
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
|
|
||||||
if clear_cuda_cache is not None:
|
|
||||||
clear_cuda_cache()
|
|
||||||
|
|
||||||
second_pass_noise = self.get_noise_like(
|
|
||||||
resized_latents, override_perlin=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear symmetry for the second pass
|
|
||||||
from dataclasses import replace
|
|
||||||
|
|
||||||
new_postprocessing_settings = replace(
|
|
||||||
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
|
|
||||||
)
|
|
||||||
new_postprocessing_settings = replace(
|
|
||||||
new_postprocessing_settings, v_symmetry_time_pct=None
|
|
||||||
)
|
|
||||||
new_conditioning_data = replace(
|
|
||||||
conditioning_data, postprocessing_settings=new_postprocessing_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
verbosity = get_verbosity()
|
|
||||||
set_verbosity_error()
|
|
||||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
|
||||||
resized_latents,
|
|
||||||
num_inference_steps=steps,
|
|
||||||
conditioning_data=new_conditioning_data,
|
|
||||||
strength=strength,
|
|
||||||
noise=second_pass_noise,
|
|
||||||
callback=step_callback,
|
|
||||||
)
|
|
||||||
set_verbosity(verbosity)
|
|
||||||
|
|
||||||
if (
|
|
||||||
pipeline_output.attention_map_saver is not None
|
|
||||||
and attention_maps_callback is not None
|
|
||||||
):
|
|
||||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
|
||||||
|
|
||||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
|
||||||
|
|
||||||
# FIXME: do we really need something entirely different for the inpainting model?
|
|
||||||
|
|
||||||
# in the case of the inpainting model being loaded, the trick of
|
|
||||||
# providing an interpolated latent doesn't work, so we transiently
|
|
||||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
|
||||||
# over it in img2img mode. Because the inpaing model is so conservative
|
|
||||||
# it doesn't change the image (much)
|
|
||||||
|
|
||||||
return make_image
|
|
||||||
|
|
||||||
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
|
|
||||||
device = like.device
|
|
||||||
if device.type == "mps":
|
|
||||||
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
|
||||||
if self.perlin > 0.0 and override_perlin == False:
|
|
||||||
shape = like.shape
|
|
||||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
|
||||||
shape[3], shape[2]
|
|
||||||
)
|
|
||||||
return x
|
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
|
||||||
def get_noise(self, width, height, scale=True):
|
|
||||||
# print(f"Get noise: {width}x{height}")
|
|
||||||
if scale:
|
|
||||||
# Scale the input width and height for the initial generation
|
|
||||||
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
|
|
||||||
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
|
|
||||||
|
|
||||||
aspect = width / height
|
|
||||||
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
|
|
||||||
min_dimension = math.floor(dimension * 0.5)
|
|
||||||
model_area = (
|
|
||||||
dimension * dimension
|
|
||||||
) # hardcoded for now since all models are trained on square images
|
|
||||||
|
|
||||||
if aspect > 1.0:
|
|
||||||
init_height = max(min_dimension, math.sqrt(model_area / aspect))
|
|
||||||
init_width = init_height * aspect
|
|
||||||
else:
|
|
||||||
init_width = max(min_dimension, math.sqrt(model_area * aspect))
|
|
||||||
init_height = init_width / aspect
|
|
||||||
|
|
||||||
scaled_width, scaled_height = trim_to_multiple_of(
|
|
||||||
math.floor(init_width), math.floor(init_height)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
scaled_width = width
|
|
||||||
scaled_height = height
|
|
||||||
|
|
||||||
device = self.model.device
|
|
||||||
channels = self.latent_channels
|
|
||||||
if channels == 9:
|
|
||||||
channels = 4 # we don't really want noise for all the mask channels
|
|
||||||
shape = (
|
|
||||||
1,
|
|
||||||
channels,
|
|
||||||
scaled_height // self.downsampling_factor,
|
|
||||||
scaled_width // self.downsampling_factor,
|
|
||||||
)
|
|
||||||
if self.use_mps_noise or device.type == "mps":
|
|
||||||
tensor = torch.empty(size=shape, device="cpu")
|
|
||||||
tensor = self.get_noise_like(like=tensor).to(device)
|
|
||||||
else:
|
|
||||||
tensor = torch.empty(size=shape, device=device)
|
|
||||||
tensor = self.get_noise_like(like=tensor)
|
|
||||||
return tensor
|
|
@ -556,8 +556,8 @@ class ModelPatcher:
|
|||||||
new_tokens_added = None
|
new_tokens_added = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ti_manager = TextualInversionManager()
|
|
||||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||||
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
def _get_trigger(ti, index):
|
||||||
@ -650,22 +650,24 @@ class TextualInversionModel:
|
|||||||
|
|
||||||
class TextualInversionManager(BaseTextualInversionManager):
|
class TextualInversionManager(BaseTextualInversionManager):
|
||||||
pad_tokens: Dict[int, List[int]]
|
pad_tokens: Dict[int, List[int]]
|
||||||
|
tokenizer: CLIPTokenizer
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, tokenizer: CLIPTokenizer):
|
||||||
self.pad_tokens = dict()
|
self.pad_tokens = dict()
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(
|
def expand_textual_inversion_token_ids_if_necessary(
|
||||||
self, token_ids: list[int]
|
self, token_ids: list[int]
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
|
|
||||||
#if token_ids[0] == self.tokenizer.bos_token_id:
|
|
||||||
# raise ValueError("token_ids must not start with bos_token_id")
|
|
||||||
#if token_ids[-1] == self.tokenizer.eos_token_id:
|
|
||||||
# raise ValueError("token_ids must not end with eos_token_id")
|
|
||||||
|
|
||||||
if len(self.pad_tokens) == 0:
|
if len(self.pad_tokens) == 0:
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||||
|
raise ValueError("token_ids must not start with bos_token_id")
|
||||||
|
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
|
raise ValueError("token_ids must not end with eos_token_id")
|
||||||
|
|
||||||
new_token_ids = []
|
new_token_ids = []
|
||||||
for token_id in token_ids:
|
for token_id in token_ids:
|
||||||
new_token_ids.append(token_id)
|
new_token_ids.append(token_id)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .base import (
|
from .base import (
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
Initialization file for invokeai.backend.prompting
|
|
||||||
"""
|
|
||||||
from .conditioning import (
|
|
||||||
get_prompt_structure,
|
|
||||||
get_tokens_for_prompt_object,
|
|
||||||
get_uc_and_c_and_ec,
|
|
||||||
split_weighted_subprompts,
|
|
||||||
)
|
|
@ -1,297 +0,0 @@
|
|||||||
"""
|
|
||||||
This module handles the generation of the conditioning tensors.
|
|
||||||
|
|
||||||
Useful function exports:
|
|
||||||
|
|
||||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
|
||||||
|
|
||||||
"""
|
|
||||||
import re
|
|
||||||
import torch
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from compel import Compel
|
|
||||||
from compel.prompt_parser import (
|
|
||||||
Blend,
|
|
||||||
CrossAttentionControlSubstitute,
|
|
||||||
FlattenedPrompt,
|
|
||||||
Fragment,
|
|
||||||
PromptParser,
|
|
||||||
Conjunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ..util import torch_dtype
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
def get_uc_and_c_and_ec(prompt_string,
|
|
||||||
model: InvokeAIDiffuserComponent,
|
|
||||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
|
||||||
# lazy-load any deferred textual inversions.
|
|
||||||
# this might take a couple of seconds the first time a textual inversion is used.
|
|
||||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
|
||||||
|
|
||||||
compel = Compel(tokenizer=model.tokenizer,
|
|
||||||
text_encoder=model.text_encoder,
|
|
||||||
textual_inversion_manager=model.textual_inversion_manager,
|
|
||||||
dtype_for_device_getter=torch_dtype,
|
|
||||||
truncate_long_prompts=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get rid of any newline characters
|
|
||||||
prompt_string = prompt_string.replace("\n", " ")
|
|
||||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
|
||||||
|
|
||||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
|
||||||
positive_conjunction: Conjunction
|
|
||||||
if legacy_blend is not None:
|
|
||||||
positive_conjunction = legacy_blend
|
|
||||||
else:
|
|
||||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
|
||||||
positive_prompt = positive_conjunction.prompts[0]
|
|
||||||
|
|
||||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
|
||||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
|
||||||
|
|
||||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
|
||||||
if log_tokens or config.log_tokenization:
|
|
||||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
|
||||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
|
||||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
|
||||||
cross_attention_control_args=options.get(
|
|
||||||
'cross_attention_control', None))
|
|
||||||
return uc, c, ec
|
|
||||||
|
|
||||||
def get_prompt_structure(
|
|
||||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
|
||||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
|
||||||
(
|
|
||||||
positive_prompt_string,
|
|
||||||
negative_prompt_string,
|
|
||||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
|
||||||
legacy_blend = try_parse_legacy_blend(
|
|
||||||
positive_prompt_string, skip_normalize_legacy_blend
|
|
||||||
)
|
|
||||||
positive_prompt: Conjunction
|
|
||||||
if legacy_blend is not None:
|
|
||||||
positive_conjunction = legacy_blend
|
|
||||||
else:
|
|
||||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
|
||||||
positive_prompt = positive_conjunction.prompts[0]
|
|
||||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
|
||||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
|
||||||
|
|
||||||
return positive_prompt, negative_prompt
|
|
||||||
|
|
||||||
def get_max_token_count(
|
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
|
||||||
) -> int:
|
|
||||||
if type(prompt) is Blend:
|
|
||||||
blend: Blend = prompt
|
|
||||||
return max(
|
|
||||||
[
|
|
||||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
|
||||||
for c in blend.prompts
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return len(
|
|
||||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_for_prompt_object(
|
|
||||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
|
||||||
) -> [str]:
|
|
||||||
if type(parsed_prompt) is Blend:
|
|
||||||
raise ValueError(
|
|
||||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
|
||||||
)
|
|
||||||
|
|
||||||
text_fragments = [
|
|
||||||
x.text
|
|
||||||
if type(x) is Fragment
|
|
||||||
else (
|
|
||||||
" ".join([f.text for f in x.original])
|
|
||||||
if type(x) is CrossAttentionControlSubstitute
|
|
||||||
else str(x)
|
|
||||||
)
|
|
||||||
for x in parsed_prompt.children
|
|
||||||
]
|
|
||||||
text = " ".join(text_fragments)
|
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
if truncate_if_too_long:
|
|
||||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
|
||||||
tokens = tokens[0:max_tokens_length]
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
|
|
||||||
unconditioned_words = ""
|
|
||||||
unconditional_regex = r"\[(.*?)\]"
|
|
||||||
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
|
|
||||||
if len(unconditionals) > 0:
|
|
||||||
unconditioned_words = " ".join(unconditionals)
|
|
||||||
|
|
||||||
# Remove Unconditioned Words From Prompt
|
|
||||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
|
||||||
clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
|
|
||||||
prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
|
|
||||||
else:
|
|
||||||
prompt_string_cleaned = prompt_string_uncleaned
|
|
||||||
return prompt_string_cleaned, unconditioned_words
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization(
|
|
||||||
positive_prompt: Union[Blend, FlattenedPrompt],
|
|
||||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
|
||||||
tokenizer,
|
|
||||||
):
|
|
||||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
|
||||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
|
||||||
|
|
||||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
|
||||||
log_tokenization_for_prompt_object(
|
|
||||||
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_prompt_object(
|
|
||||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
|
||||||
):
|
|
||||||
display_label_prefix = display_label_prefix or ""
|
|
||||||
if type(p) is Blend:
|
|
||||||
blend: Blend = p
|
|
||||||
for i, c in enumerate(blend.prompts):
|
|
||||||
log_tokenization_for_prompt_object(
|
|
||||||
c,
|
|
||||||
tokenizer,
|
|
||||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
|
||||||
)
|
|
||||||
elif type(p) is FlattenedPrompt:
|
|
||||||
flattened_prompt: FlattenedPrompt = p
|
|
||||||
if flattened_prompt.wants_cross_attention_control:
|
|
||||||
original_fragments = []
|
|
||||||
edited_fragments = []
|
|
||||||
for f in flattened_prompt.children:
|
|
||||||
if type(f) is CrossAttentionControlSubstitute:
|
|
||||||
original_fragments += f.original
|
|
||||||
edited_fragments += f.edited
|
|
||||||
else:
|
|
||||||
original_fragments.append(f)
|
|
||||||
edited_fragments.append(f)
|
|
||||||
|
|
||||||
original_text = " ".join([x.text for x in original_fragments])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
original_text,
|
|
||||||
tokenizer,
|
|
||||||
display_label=f"{display_label_prefix}(.swap originals)",
|
|
||||||
)
|
|
||||||
edited_text = " ".join([x.text for x in edited_fragments])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
edited_text,
|
|
||||||
tokenizer,
|
|
||||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text = " ".join([x.text for x in flattened_prompt.children])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
text, tokenizer, display_label=display_label_prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
|
||||||
"""shows how the prompt is tokenized
|
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
|
||||||
# but for readability it has been replaced with ' '
|
|
||||||
"""
|
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
tokenized = ""
|
|
||||||
discarded = ""
|
|
||||||
usedTokens = 0
|
|
||||||
totalTokens = len(tokens)
|
|
||||||
|
|
||||||
for i in range(0, totalTokens):
|
|
||||||
token = tokens[i].replace("</w>", " ")
|
|
||||||
# alternate color
|
|
||||||
s = (usedTokens % 6) + 1
|
|
||||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
|
||||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
|
||||||
else:
|
|
||||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
|
||||||
usedTokens += 1
|
|
||||||
|
|
||||||
if usedTokens > 0:
|
|
||||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
|
||||||
logger.debug(f"{tokenized}\x1b[0m")
|
|
||||||
|
|
||||||
if discarded != "":
|
|
||||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
|
||||||
logger.debug(f"{discarded}\x1b[0m")
|
|
||||||
|
|
||||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
|
||||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
|
||||||
if len(weighted_subprompts) <= 1:
|
|
||||||
return None
|
|
||||||
strings = [x[0] for x in weighted_subprompts]
|
|
||||||
|
|
||||||
pp = PromptParser()
|
|
||||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
|
||||||
flattened_prompts = []
|
|
||||||
weights = []
|
|
||||||
for i, x in enumerate(parsed_conjunctions):
|
|
||||||
if len(x.prompts)>0:
|
|
||||||
flattened_prompts.append(x.prompts[0])
|
|
||||||
weights.append(weighted_subprompts[i][1])
|
|
||||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
|
||||||
|
|
||||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
|
||||||
"""
|
|
||||||
Legacy blend parsing.
|
|
||||||
|
|
||||||
grabs all text up to the first occurrence of ':'
|
|
||||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
|
||||||
if ':' has no value defined, defaults to 1.0
|
|
||||||
repeats until no text remaining
|
|
||||||
"""
|
|
||||||
prompt_parser = re.compile(
|
|
||||||
"""
|
|
||||||
(?P<prompt> # capture group for 'prompt'
|
|
||||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
|
||||||
) # end 'prompt'
|
|
||||||
(?: # non-capture group
|
|
||||||
:+ # match one or more ':' characters
|
|
||||||
(?P<weight> # capture group for 'weight'
|
|
||||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
|
||||||
)? # end weight capture group, make optional
|
|
||||||
\s* # strip spaces after weight
|
|
||||||
| # OR
|
|
||||||
$ # else, if no ':' then match end of line
|
|
||||||
) # end non-capture group
|
|
||||||
""",
|
|
||||||
re.VERBOSE,
|
|
||||||
)
|
|
||||||
parsed_prompts = [
|
|
||||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
|
||||||
for match in re.finditer(prompt_parser, text)
|
|
||||||
]
|
|
||||||
if len(parsed_prompts) == 0:
|
|
||||||
return []
|
|
||||||
if skip_normalize:
|
|
||||||
return parsed_prompts
|
|
||||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
|
||||||
if weight_sum == 0:
|
|
||||||
logger.warning(
|
|
||||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
|
||||||
)
|
|
||||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
|
||||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
|
||||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
|
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for the invokeai.backend.stable_diffusion package
|
Initialization file for the invokeai.backend.stable_diffusion package
|
||||||
"""
|
"""
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
|
||||||
from .diffusers_pipeline import (
|
from .diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
PipelineIntermediateState,
|
PipelineIntermediateState,
|
||||||
@ -10,4 +9,3 @@ from .diffusers_pipeline import (
|
|||||||
from .diffusion import InvokeAIDiffuserComponent
|
from .diffusion import InvokeAIDiffuserComponent
|
||||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from .textual_inversion_manager import TextualInversionManager
|
|
||||||
|
@ -1,275 +0,0 @@
|
|||||||
"""
|
|
||||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
|
||||||
at https://huggingface.co/sd-concepts-library.
|
|
||||||
|
|
||||||
The interface is through the Concepts() object.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import Callable
|
|
||||||
from urllib import error as ul_error
|
|
||||||
from urllib import request
|
|
||||||
|
|
||||||
from huggingface_hub import (
|
|
||||||
HfApi,
|
|
||||||
HfFolder,
|
|
||||||
ModelFilter,
|
|
||||||
hf_hub_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
logger = InvokeAILogger.getLogger()
|
|
||||||
|
|
||||||
class HuggingFaceConceptsLibrary(object):
|
|
||||||
def __init__(self, root=None):
|
|
||||||
"""
|
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
|
||||||
"""
|
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
|
||||||
self.root = root or self.config.root
|
|
||||||
self.hf_api = HfApi()
|
|
||||||
self.local_concepts = dict()
|
|
||||||
self.concept_list = None
|
|
||||||
self.concepts_loaded = dict()
|
|
||||||
self.triggers = dict() # concept name to trigger phrase
|
|
||||||
self.concept_names = dict() # trigger phrase to concept name
|
|
||||||
self.match_trigger = re.compile(
|
|
||||||
"(<[\w\- >]+>)"
|
|
||||||
) # trigger is slightly less restrictive than HF concept name
|
|
||||||
self.match_concept = re.compile(
|
|
||||||
"<([\w\-]+)>"
|
|
||||||
) # HF concept name can only contain A-Za-z0-9_-
|
|
||||||
|
|
||||||
def list_concepts(self) -> list:
|
|
||||||
"""
|
|
||||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
|
||||||
Also adds local concepts in invokeai/embeddings folder.
|
|
||||||
"""
|
|
||||||
local_concepts_now = self.get_local_concepts(
|
|
||||||
os.path.join(self.root, "embeddings")
|
|
||||||
)
|
|
||||||
local_concepts_to_add = set(local_concepts_now).difference(
|
|
||||||
set(self.local_concepts)
|
|
||||||
)
|
|
||||||
self.local_concepts.update(local_concepts_now)
|
|
||||||
|
|
||||||
if self.concept_list is not None:
|
|
||||||
if local_concepts_to_add:
|
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
|
||||||
return self.concept_list
|
|
||||||
return self.concept_list
|
|
||||||
elif self.config.internet_available is True:
|
|
||||||
try:
|
|
||||||
models = self.hf_api.list_models(
|
|
||||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
|
||||||
)
|
|
||||||
self.concept_list = [a.id.split("/")[1] for a in models]
|
|
||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
|
||||||
)
|
|
||||||
return self.concept_list
|
|
||||||
else:
|
|
||||||
return self.concept_list
|
|
||||||
|
|
||||||
def get_concept_model_path(self, concept_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Returns the path to the 'learned_embeds.bin' file in
|
|
||||||
the named concept. Returns None if invalid or cannot
|
|
||||||
be downloaded.
|
|
||||||
"""
|
|
||||||
if not concept_name in self.list_concepts():
|
|
||||||
logger.warning(
|
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
|
||||||
|
|
||||||
def concept_to_trigger(self, concept_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a concept name returns its trigger by looking in the
|
|
||||||
"token_identifier.txt" file.
|
|
||||||
"""
|
|
||||||
if concept_name in self.triggers:
|
|
||||||
return self.triggers[concept_name]
|
|
||||||
elif self.concept_is_local(concept_name):
|
|
||||||
trigger = f"<{concept_name}>"
|
|
||||||
self.triggers[concept_name] = trigger
|
|
||||||
self.concept_names[trigger] = concept_name
|
|
||||||
return trigger
|
|
||||||
|
|
||||||
file = self.get_concept_file(
|
|
||||||
concept_name, "token_identifier.txt", local_only=True
|
|
||||||
)
|
|
||||||
if not file:
|
|
||||||
return None
|
|
||||||
with open(file, "r") as f:
|
|
||||||
trigger = f.readline()
|
|
||||||
trigger = trigger.strip()
|
|
||||||
self.triggers[concept_name] = trigger
|
|
||||||
self.concept_names[trigger] = concept_name
|
|
||||||
return trigger
|
|
||||||
|
|
||||||
def trigger_to_concept(self, trigger: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a trigger phrase, maps it to the concept library name.
|
|
||||||
Only works if concept_to_trigger() has previously been called
|
|
||||||
on this library. There needs to be a persistent database for
|
|
||||||
this.
|
|
||||||
"""
|
|
||||||
concept = self.concept_names.get(trigger, None)
|
|
||||||
return f"<{concept}>" if concept else f"{trigger}"
|
|
||||||
|
|
||||||
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a prompt string that contains <trigger> tags, replace these
|
|
||||||
tags with the concept name. The reason for this is so that the
|
|
||||||
concept names get stored in the prompt metadata. There is no
|
|
||||||
controlling of colliding triggers in the SD library, so it is
|
|
||||||
better to store the concept name (unique) than the concept trigger
|
|
||||||
(not necessarily unique!)
|
|
||||||
"""
|
|
||||||
if not prompt:
|
|
||||||
return prompt
|
|
||||||
triggers = self.match_trigger.findall(prompt)
|
|
||||||
if not triggers:
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def do_replace(match) -> str:
|
|
||||||
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
|
||||||
|
|
||||||
return self.match_trigger.sub(do_replace, prompt)
|
|
||||||
|
|
||||||
def replace_concepts_with_triggers(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
load_concepts_callback: Callable[[list], any],
|
|
||||||
excluded_tokens: list[str],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Given a prompt string that contains `<concept_name>` tags, replace
|
|
||||||
these tags with the appropriate trigger.
|
|
||||||
|
|
||||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
|
||||||
of `concepts_name` strings.
|
|
||||||
|
|
||||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
|
||||||
are trigger tokens from a locally-loaded embedding.
|
|
||||||
"""
|
|
||||||
concepts = self.match_concept.findall(prompt)
|
|
||||||
if not concepts:
|
|
||||||
return prompt
|
|
||||||
load_concepts_callback(concepts)
|
|
||||||
|
|
||||||
def do_replace(match) -> str:
|
|
||||||
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
|
||||||
return f"<{match.group(1)}>"
|
|
||||||
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
|
||||||
|
|
||||||
return self.match_concept.sub(do_replace, prompt)
|
|
||||||
|
|
||||||
def get_concept_file(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
file_name: str = "learned_embeds.bin",
|
|
||||||
local_only: bool = False,
|
|
||||||
) -> str:
|
|
||||||
if not (
|
|
||||||
self.concept_is_downloaded(concept_name)
|
|
||||||
or self.concept_is_local(concept_name)
|
|
||||||
or local_only
|
|
||||||
):
|
|
||||||
self.download_concept(concept_name)
|
|
||||||
|
|
||||||
# get local path in invokeai/embeddings if local concept
|
|
||||||
if self.concept_is_local(concept_name):
|
|
||||||
concept_path = self._concept_local_path(concept_name)
|
|
||||||
path = concept_path
|
|
||||||
else:
|
|
||||||
concept_path = self._concept_path(concept_name)
|
|
||||||
path = os.path.join(concept_path, file_name)
|
|
||||||
return path if os.path.exists(path) else None
|
|
||||||
|
|
||||||
def concept_is_local(self, concept_name) -> bool:
|
|
||||||
return concept_name in self.local_concepts
|
|
||||||
|
|
||||||
def concept_is_downloaded(self, concept_name) -> bool:
|
|
||||||
concept_directory = self._concept_path(concept_name)
|
|
||||||
return os.path.exists(concept_directory)
|
|
||||||
|
|
||||||
def download_concept(self, concept_name) -> bool:
|
|
||||||
repo_id = self._concept_id(concept_name)
|
|
||||||
dest = self._concept_path(concept_name)
|
|
||||||
|
|
||||||
access_token = HfFolder.get_token()
|
|
||||||
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
|
||||||
opener = request.build_opener()
|
|
||||||
opener.addheaders = header
|
|
||||||
request.install_opener(opener)
|
|
||||||
|
|
||||||
os.makedirs(dest, exist_ok=True)
|
|
||||||
succeeded = True
|
|
||||||
|
|
||||||
bytes = 0
|
|
||||||
|
|
||||||
def tally_download_size(chunk, size, total):
|
|
||||||
nonlocal bytes
|
|
||||||
if chunk == 0:
|
|
||||||
bytes += total
|
|
||||||
|
|
||||||
logger.info(f"Downloading {repo_id}...", end="")
|
|
||||||
try:
|
|
||||||
for file in (
|
|
||||||
"README.md",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"token_identifier.txt",
|
|
||||||
"type_of_concept.txt",
|
|
||||||
):
|
|
||||||
url = hf_hub_url(repo_id, file)
|
|
||||||
request.urlretrieve(
|
|
||||||
url, os.path.join(dest, file), reporthook=tally_download_size
|
|
||||||
)
|
|
||||||
except ul_error.HTTPError as e:
|
|
||||||
if e.code == 404:
|
|
||||||
logger.warning(
|
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
|
||||||
)
|
|
||||||
os.rmdir(dest)
|
|
||||||
return False
|
|
||||||
except ul_error.URLError as e:
|
|
||||||
logger.error(
|
|
||||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
os.rmdir(dest)
|
|
||||||
return False
|
|
||||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
|
||||||
return succeeded
|
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
|
||||||
return f"sd-concepts-library/{concept_name}"
|
|
||||||
|
|
||||||
def _concept_path(self, concept_name: str) -> str:
|
|
||||||
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
|
||||||
|
|
||||||
def _concept_local_path(self, concept_name: str) -> str:
|
|
||||||
filename = self.local_concepts[concept_name]
|
|
||||||
return os.path.join(self.root, "embeddings", filename)
|
|
||||||
|
|
||||||
def get_local_concepts(self, loc_dir: str):
|
|
||||||
locs_dic = dict()
|
|
||||||
if os.path.isdir(loc_dir):
|
|
||||||
for file in os.listdir(loc_dir):
|
|
||||||
f = os.path.splitext(file)
|
|
||||||
if f[1] == ".bin" or f[1] == ".pt":
|
|
||||||
locs_dic[f[0]] = file
|
|
||||||
return locs_dic
|
|
@ -16,7 +16,6 @@ from accelerate.utils import set_seed
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from compel import EmbeddingsProvider
|
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
@ -48,7 +47,6 @@ from .diffusion import (
|
|||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
)
|
)
|
||||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||||
from .textual_inversion_manager import TextualInversionManager
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineIntermediateState:
|
class PipelineIntermediateState:
|
||||||
@ -317,6 +315,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
requires_safety_checker: bool = False,
|
requires_safety_checker: bool = False,
|
||||||
precision: str = "float32",
|
precision: str = "float32",
|
||||||
control_model: ControlNetModel = None,
|
control_model: ControlNetModel = None,
|
||||||
|
execution_device: Optional[torch.device] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vae,
|
vae,
|
||||||
@ -341,22 +340,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# control_model=control_model,
|
# control_model=control_model,
|
||||||
)
|
)
|
||||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||||
self.unet, self._unet_forward, is_running_diffusers=True
|
self.unet, self._unet_forward
|
||||||
)
|
|
||||||
use_full_precision = precision == "float32" or precision == "autocast"
|
|
||||||
self.textual_inversion_manager = TextualInversionManager(
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
text_encoder=self.text_encoder,
|
|
||||||
full_precision=use_full_precision,
|
|
||||||
)
|
|
||||||
# InvokeAI's interface for text embeddings and whatnot
|
|
||||||
self.embeddings_provider = EmbeddingsProvider(
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
text_encoder=self.text_encoder,
|
|
||||||
textual_inversion_manager=self.textual_inversion_manager,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||||
self._model_group.install(*self._submodels)
|
self._model_group.install(*self._submodels)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
|
|
||||||
@ -404,50 +391,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
self.disable_attention_slicing()
|
self.disable_attention_slicing()
|
||||||
|
|
||||||
def enable_offload_submodels(self, device: torch.device):
|
|
||||||
"""
|
|
||||||
Offload each submodel when it's not in use.
|
|
||||||
|
|
||||||
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
|
|
||||||
the total available resource, and you want to free up as much for inference as possible.
|
|
||||||
|
|
||||||
This requires more moving parts and may add some delay as the U-Net is swapped out for the
|
|
||||||
VAE and vice-versa.
|
|
||||||
"""
|
|
||||||
models = self._submodels
|
|
||||||
if self._model_group is not None:
|
|
||||||
self._model_group.uninstall(*models)
|
|
||||||
group = LazilyLoadedModelGroup(device)
|
|
||||||
group.install(*models)
|
|
||||||
self._model_group = group
|
|
||||||
|
|
||||||
def disable_offload_submodels(self):
|
|
||||||
"""
|
|
||||||
Leave all submodels loaded.
|
|
||||||
|
|
||||||
Appropriate for cases where the size of the model in memory is small compared to the memory
|
|
||||||
required for inference. Avoids the delay and complexity of shuffling the submodels to and
|
|
||||||
from the GPU.
|
|
||||||
"""
|
|
||||||
models = self._submodels
|
|
||||||
if self._model_group is not None:
|
|
||||||
self._model_group.uninstall(*models)
|
|
||||||
group = FullyLoadedModelGroup(self._model_group.execution_device)
|
|
||||||
group.install(*models)
|
|
||||||
self._model_group = group
|
|
||||||
|
|
||||||
def offload_all(self):
|
|
||||||
"""Offload all this pipeline's models to CPU."""
|
|
||||||
self._model_group.offload_current()
|
|
||||||
|
|
||||||
def ready(self):
|
|
||||||
"""
|
|
||||||
Ready this pipeline's models.
|
|
||||||
|
|
||||||
i.e. preload them to the GPU if appropriate.
|
|
||||||
"""
|
|
||||||
self._model_group.ready()
|
|
||||||
|
|
||||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||||
# overridden method; types match the superclass.
|
# overridden method; types match the superclass.
|
||||||
if torch_device is None:
|
if torch_device is None:
|
||||||
@ -991,25 +934,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
device = self._model_group.device_for(self.safety_checker)
|
device = self._model_group.device_for(self.safety_checker)
|
||||||
return super().run_safety_checker(image, device, dtype)
|
return super().run_safety_checker(image, device, dtype)
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def get_learned_conditioning(
|
|
||||||
self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
|
|
||||||
"""
|
|
||||||
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
|
|
||||||
text_batch=c,
|
|
||||||
fragment_weights_batch=fragment_weights,
|
|
||||||
should_return_tokens=return_tokens,
|
|
||||||
device=self._model_group.device_for(self.unet),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def channels(self) -> int:
|
|
||||||
"""Compatible with DiffusionWrapper"""
|
|
||||||
return self.unet.config.in_channels
|
|
||||||
|
|
||||||
def decode_latents(self, latents):
|
def decode_latents(self, latents):
|
||||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||||
self._model_group.load(self.vae)
|
self._model_group.load(self.vae)
|
||||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
|||||||
CrossAttentionType,
|
CrossAttentionType,
|
||||||
SwapCrossAttnContext,
|
SwapCrossAttnContext,
|
||||||
get_cross_attention_modules,
|
get_cross_attention_modules,
|
||||||
restore_default_cross_attention,
|
|
||||||
setup_cross_attention_control_attention_processors,
|
setup_cross_attention_control_attention_processors,
|
||||||
)
|
)
|
||||||
from .cross_attention_map_saving import AttentionMapSaver
|
from .cross_attention_map_saving import AttentionMapSaver
|
||||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
model_forward_callback: ModelForwardCallback,
|
model_forward_callback: ModelForwardCallback,
|
||||||
is_running_diffusers: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_running_diffusers = is_running_diffusers
|
|
||||||
self.model_forward_callback = model_forward_callback
|
self.model_forward_callback = model_forward_callback
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
# self.remove_attention_map_saving()
|
# self.remove_attention_map_saving()
|
||||||
|
|
||||||
# apparently unused code
|
|
||||||
# TODO: delete
|
|
||||||
# def override_cross_attention(
|
|
||||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
|
||||||
# ) -> Dict[str, AttentionProcessor]:
|
|
||||||
# """
|
|
||||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
|
||||||
# the previous attention processor is returned so that the caller can restore it later.
|
|
||||||
# """
|
|
||||||
# self.conditioning = conditioning
|
|
||||||
# self.cross_attention_control_context = Context(
|
|
||||||
# arguments=self.conditioning.cross_attention_control_args,
|
|
||||||
# step_count=step_count,
|
|
||||||
# )
|
|
||||||
# return override_cross_attention(
|
|
||||||
# self.model,
|
|
||||||
# self.cross_attention_control_context,
|
|
||||||
# is_running_diffusers=self.is_running_diffusers,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def restore_default_cross_attention(
|
|
||||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
|
||||||
):
|
|
||||||
self.conditioning = None
|
|
||||||
self.cross_attention_control_context = None
|
|
||||||
restore_default_cross_attention(
|
|
||||||
self.model,
|
|
||||||
is_running_diffusers=self.is_running_diffusers,
|
|
||||||
restore_attention_processor=restore_attention_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||||
def callback(slice, dim, offset, slice_size, key):
|
def callback(slice, dim, offset, slice_size, key):
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = self.calculate_percent_through(
|
percent_through = step_index / total_step_count
|
||||||
sigma, step_index, total_step_count
|
|
||||||
)
|
|
||||||
cross_attention_control_types_to_do = (
|
cross_attention_control_types_to_do = (
|
||||||
context.get_active_cross_attention_control_types_for_step(
|
context.get_active_cross_attention_control_types_for_step(
|
||||||
percent_through
|
percent_through
|
||||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
total_step_count,
|
total_step_count,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if postprocessing_settings is not None:
|
if postprocessing_settings is not None:
|
||||||
percent_through = self.calculate_percent_through(
|
percent_through = step_index / total_step_count
|
||||||
sigma, step_index, total_step_count
|
|
||||||
)
|
|
||||||
latents = self.apply_threshold(
|
latents = self.apply_threshold(
|
||||||
postprocessing_settings, latents, percent_through
|
postprocessing_settings, latents, percent_through
|
||||||
)
|
)
|
||||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
|
||||||
if step_index is not None and total_step_count is not None:
|
|
||||||
# 🧨diffusers codepath
|
|
||||||
percent_through = (
|
|
||||||
step_index / total_step_count
|
|
||||||
) # will never reach 1.0 - this is deliberate
|
|
||||||
else:
|
|
||||||
# legacy compvis codepath
|
|
||||||
# TODO remove when compvis codepath support is dropped
|
|
||||||
if step_index is None and sigma is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
|
||||||
)
|
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
|
||||||
return percent_through
|
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioned_next_x = conditioned_next_x.clone()
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
# TODO: looks unused
|
||||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
assert isinstance(conditioning, dict)
|
assert isinstance(conditioning, dict)
|
||||||
assert isinstance(unconditioning, dict)
|
assert isinstance(unconditioning, dict)
|
||||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
|
||||||
if self.is_running_diffusers:
|
|
||||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
|
||||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
|
||||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
|
||||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
|
||||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
|
||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
|
||||||
context: Context = self.cross_attention_control_context
|
|
||||||
|
|
||||||
try:
|
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
|
||||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
|
||||||
context.request_save_attention_maps(ca_type)
|
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
|
||||||
context.clear_requests(cleanup=False)
|
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
|
||||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
|
||||||
context.request_apply_saved_attention_maps(ca_type)
|
|
||||||
edited_conditioning = (
|
|
||||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
|
||||||
)
|
|
||||||
conditioned_next_x = self.model_forward_callback(
|
|
||||||
x, sigma, edited_conditioning, **kwargs,
|
|
||||||
)
|
|
||||||
context.clear_requests(cleanup=True)
|
|
||||||
|
|
||||||
except:
|
|
||||||
context.clear_requests(cleanup=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
|
||||||
|
|
||||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||||
|
@ -1,429 +0,0 @@
|
|||||||
import traceback
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union, List
|
|
||||||
|
|
||||||
import safetensors.torch
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
|
||||||
from picklescan.scanner import scan_file_path
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EmbeddingInfo:
|
|
||||||
name: str
|
|
||||||
embedding: torch.Tensor
|
|
||||||
num_vectors_per_token: int
|
|
||||||
token_dim: int
|
|
||||||
trained_steps: int = None
|
|
||||||
trained_model_name: str = None
|
|
||||||
trained_model_checksum: str = None
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TextualInversion:
|
|
||||||
trigger_string: str
|
|
||||||
embedding: torch.Tensor
|
|
||||||
trigger_token_id: Optional[int] = None
|
|
||||||
pad_token_ids: Optional[list[int]] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_vector_length(self) -> int:
|
|
||||||
return self.embedding.shape[0]
|
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionManager(BaseTextualInversionManager):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer: CLIPTokenizer,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
full_precision: bool = True,
|
|
||||||
):
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.text_encoder = text_encoder
|
|
||||||
self.full_precision = full_precision
|
|
||||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
|
||||||
self.trigger_to_sourcefile = dict()
|
|
||||||
default_textual_inversions: list[TextualInversion] = []
|
|
||||||
self.textual_inversions = default_textual_inversions
|
|
||||||
|
|
||||||
def load_huggingface_concepts(self, concepts: list[str]):
|
|
||||||
for concept_name in concepts:
|
|
||||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
|
||||||
continue
|
|
||||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
|
||||||
if (
|
|
||||||
self.has_textual_inversion_for_trigger_string(trigger)
|
|
||||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
|
||||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
|
||||||
): # in case a token with literal angle brackets encountered
|
|
||||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
|
||||||
continue
|
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
|
||||||
if not bin_file:
|
|
||||||
continue
|
|
||||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
|
||||||
self.load_textual_inversion(bin_file)
|
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
|
||||||
|
|
||||||
def get_all_trigger_strings(self) -> list[str]:
|
|
||||||
return [ti.trigger_string for ti in self.textual_inversions]
|
|
||||||
|
|
||||||
def load_textual_inversion(
|
|
||||||
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
|
|
||||||
):
|
|
||||||
ckpt_path = Path(ckpt_path)
|
|
||||||
|
|
||||||
if not ckpt_path.is_file():
|
|
||||||
return
|
|
||||||
|
|
||||||
if str(ckpt_path).endswith(".DS_Store"):
|
|
||||||
return
|
|
||||||
|
|
||||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
|
||||||
for embedding_info in embedding_list:
|
|
||||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
|
||||||
logger.warning(
|
|
||||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Resolve the situation in which an earlier embedding has claimed the same
|
|
||||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
|
||||||
trigger_str = embedding_info.name
|
|
||||||
sourcefile = (
|
|
||||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
|
||||||
if ckpt_path.name == "learned_embeds.bin"
|
|
||||||
else ckpt_path.name
|
|
||||||
)
|
|
||||||
|
|
||||||
if trigger_str in self.trigger_to_sourcefile:
|
|
||||||
replacement_trigger_str = (
|
|
||||||
f"<{ckpt_path.parent.name}>"
|
|
||||||
if ckpt_path.name == "learned_embeds.bin"
|
|
||||||
else f"<{ckpt_path.stem}>"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
|
||||||
)
|
|
||||||
trigger_str = replacement_trigger_str
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._add_textual_inversion(
|
|
||||||
trigger_str,
|
|
||||||
embedding_info.embedding,
|
|
||||||
defer_injecting_tokens=defer_injecting_tokens,
|
|
||||||
)
|
|
||||||
# remember which source file claims this trigger
|
|
||||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
|
||||||
logger.debug(f"The error was {str(e)}")
|
|
||||||
|
|
||||||
def _add_textual_inversion(
|
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
|
||||||
) -> Optional[TextualInversion]:
|
|
||||||
"""
|
|
||||||
Add a textual inversion to be recognised.
|
|
||||||
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
|
||||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
|
||||||
:return: The token id for the added embedding, either existing or newly-added.
|
|
||||||
"""
|
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
|
||||||
logger.warning(
|
|
||||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if not self.full_precision:
|
|
||||||
embedding = embedding.half()
|
|
||||||
if len(embedding.shape) == 1:
|
|
||||||
embedding = embedding.unsqueeze(0)
|
|
||||||
elif len(embedding.shape) > 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
|
|
||||||
if not defer_injecting_tokens:
|
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
|
||||||
self.textual_inversions.append(ti)
|
|
||||||
return ti
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
if str(e).startswith("Warning"):
|
|
||||||
logger.warning(f"{str(e)}")
|
|
||||||
else:
|
|
||||||
traceback.print_exc()
|
|
||||||
logger.error(
|
|
||||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
|
||||||
if ti.trigger_token_id is not None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
|
|
||||||
ti.trigger_string, ti.embedding[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if ti.embedding_vector_length > 1:
|
|
||||||
# for embeddings with vector length > 1
|
|
||||||
pad_token_strings = [
|
|
||||||
ti.trigger_string + "-!pad-" + str(pad_index)
|
|
||||||
for pad_index in range(1, ti.embedding_vector_length)
|
|
||||||
]
|
|
||||||
# todo: batched UI for faster loading when vector length >2
|
|
||||||
pad_token_ids = [
|
|
||||||
self._get_or_create_token_id_and_assign_embedding(
|
|
||||||
pad_token_str, ti.embedding[1 + i]
|
|
||||||
)
|
|
||||||
for (i, pad_token_str) in enumerate(pad_token_strings)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
pad_token_ids = []
|
|
||||||
|
|
||||||
ti.trigger_token_id = trigger_token_id
|
|
||||||
ti.pad_token_ids = pad_token_ids
|
|
||||||
return ti.trigger_token_id
|
|
||||||
|
|
||||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
|
||||||
try:
|
|
||||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
|
||||||
return ti is not None
|
|
||||||
except StopIteration:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_textual_inversion_for_trigger_string(
|
|
||||||
self, trigger_string: str
|
|
||||||
) -> TextualInversion:
|
|
||||||
return next(
|
|
||||||
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
|
||||||
return next(
|
|
||||||
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_deferred_token_ids_for_any_trigger_terms(
|
|
||||||
self, prompt_string: str
|
|
||||||
) -> list[int]:
|
|
||||||
injected_token_ids = []
|
|
||||||
for ti in self.textual_inversions:
|
|
||||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
|
||||||
if ti.embedding_vector_length > 1:
|
|
||||||
logger.info(
|
|
||||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.debug(
|
|
||||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
|
||||||
)
|
|
||||||
logger.debug(f"The error was {str(e)}")
|
|
||||||
continue
|
|
||||||
injected_token_ids.append(ti.trigger_token_id)
|
|
||||||
injected_token_ids.extend(ti.pad_token_ids)
|
|
||||||
return injected_token_ids
|
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(
|
|
||||||
self, prompt_token_ids: list[int]
|
|
||||||
) -> list[int]:
|
|
||||||
"""
|
|
||||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
|
||||||
|
|
||||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
|
||||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
|
||||||
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
|
||||||
"""
|
|
||||||
if len(prompt_token_ids) == 0:
|
|
||||||
return prompt_token_ids
|
|
||||||
|
|
||||||
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
|
||||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
|
||||||
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
|
||||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
|
||||||
textual_inversion_trigger_token_ids = [
|
|
||||||
ti.trigger_token_id for ti in self.textual_inversions
|
|
||||||
]
|
|
||||||
prompt_token_ids = prompt_token_ids.copy()
|
|
||||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
|
||||||
if token_id in textual_inversion_trigger_token_ids:
|
|
||||||
textual_inversion = next(
|
|
||||||
ti
|
|
||||||
for ti in self.textual_inversions
|
|
||||||
if ti.trigger_token_id == token_id
|
|
||||||
)
|
|
||||||
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
|
|
||||||
prompt_token_ids.insert(
|
|
||||||
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
|
|
||||||
)
|
|
||||||
|
|
||||||
return prompt_token_ids
|
|
||||||
|
|
||||||
def _get_or_create_token_id_and_assign_embedding(
|
|
||||||
self, token_str: str, embedding: torch.Tensor
|
|
||||||
) -> int:
|
|
||||||
if len(embedding.shape) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
|
|
||||||
)
|
|
||||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
|
||||||
if existing_token_id == self.tokenizer.unk_token_id:
|
|
||||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
|
||||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
|
||||||
current_token_count = current_embeddings.num_embeddings
|
|
||||||
new_token_count = current_token_count + num_tokens_added
|
|
||||||
# the following call is slow - todo make batched for better performance with vector length >1
|
|
||||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
|
||||||
|
|
||||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
|
||||||
if token_id == self.tokenizer.unk_token_id:
|
|
||||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
|
||||||
if (
|
|
||||||
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
|
|
||||||
!= embedding.shape
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
|
|
||||||
)
|
|
||||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
|
||||||
|
|
||||||
return token_id
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
|
||||||
suffix = Path(embedding_file).suffix
|
|
||||||
try:
|
|
||||||
if suffix in [".pt",".ckpt",".bin"]:
|
|
||||||
scan_result = scan_file_path(embedding_file)
|
|
||||||
if scan_result.infected_files > 0:
|
|
||||||
logger.critical(
|
|
||||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
|
||||||
)
|
|
||||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
|
||||||
return list()
|
|
||||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
|
||||||
else:
|
|
||||||
ckpt = safetensors.torch.load_file(embedding_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
|
||||||
return list()
|
|
||||||
|
|
||||||
# try to figure out what kind of embedding file it is and parse accordingly
|
|
||||||
keys = list(ckpt.keys())
|
|
||||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
|
||||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
|
||||||
|
|
||||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
|
||||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
|
||||||
|
|
||||||
elif 'emb_params' in keys:
|
|
||||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
|
||||||
|
|
||||||
else:
|
|
||||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
|
||||||
|
|
||||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
|
||||||
basename = Path(file_path).stem
|
|
||||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
|
||||||
|
|
||||||
embeddings = list()
|
|
||||||
token_counter = -1
|
|
||||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
|
||||||
if token_counter < 0:
|
|
||||||
trigger = embedding_ckpt["name"]
|
|
||||||
elif token_counter == 0:
|
|
||||||
trigger = '<basename>'
|
|
||||||
else:
|
|
||||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
|
||||||
token_counter += 1
|
|
||||||
embedding_info = EmbeddingInfo(
|
|
||||||
name = trigger,
|
|
||||||
embedding = embedding,
|
|
||||||
num_vectors_per_token = embedding.size()[0],
|
|
||||||
token_dim = embedding.size()[1],
|
|
||||||
trained_steps = embedding_ckpt["step"],
|
|
||||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
|
||||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
|
||||||
)
|
|
||||||
embeddings.append(embedding_info)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def _parse_embedding_v2 (
|
|
||||||
self, embedding_ckpt: dict, file_path: str
|
|
||||||
) -> List[EmbeddingInfo]:
|
|
||||||
"""
|
|
||||||
This handles embedding .pt file variant #2.
|
|
||||||
"""
|
|
||||||
basename = Path(file_path).stem
|
|
||||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
|
||||||
embeddings = list()
|
|
||||||
|
|
||||||
if isinstance(
|
|
||||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
|
||||||
):
|
|
||||||
token_counter = 0
|
|
||||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
|
||||||
trigger = token if token != '*' \
|
|
||||||
else f'<{basename}>' if token_counter == 0 \
|
|
||||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
|
||||||
embedding_info = EmbeddingInfo(
|
|
||||||
name = trigger,
|
|
||||||
embedding = embedding,
|
|
||||||
num_vectors_per_token = embedding.size()[0],
|
|
||||||
token_dim = embedding.size()[1],
|
|
||||||
)
|
|
||||||
embeddings.append(embedding_info)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
|
||||||
"""
|
|
||||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
|
||||||
"""
|
|
||||||
basename = Path(file_path).stem
|
|
||||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
|
||||||
embedding = embedding_ckpt['emb_params']
|
|
||||||
embedding_info = EmbeddingInfo(
|
|
||||||
name = f'<{basename}>',
|
|
||||||
embedding = embedding,
|
|
||||||
num_vectors_per_token = embedding.size()[0],
|
|
||||||
token_dim = embedding.size()[1],
|
|
||||||
)
|
|
||||||
return [embedding_info]
|
|
||||||
|
|
||||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
|
||||||
"""
|
|
||||||
Parse 'version 4' of the textual inversion embedding files. This one
|
|
||||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
|
||||||
"""
|
|
||||||
basename = Path(filepath).stem
|
|
||||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
|
||||||
|
|
||||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
|
||||||
|
|
||||||
embeddings = list()
|
|
||||||
if list(embedding_ckpt.keys()) == 0:
|
|
||||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
|
||||||
else:
|
|
||||||
for token,embedding in embedding_ckpt.items():
|
|
||||||
embedding_info = EmbeddingInfo(
|
|
||||||
name = token or f"<{basename}>",
|
|
||||||
embedding = embedding,
|
|
||||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
|
||||||
token_dim = embedding.size()[0],
|
|
||||||
)
|
|
||||||
embeddings.append(embedding_info)
|
|
||||||
return embeddings
|
|
@ -1,11 +1,10 @@
|
|||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { sessionCreated } from 'services/thunks/session';
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { imageUpdated, imageUploaded } from 'services/thunks/image';
|
import { imageUpdated, imageUploaded } from 'services/thunks/image';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { ImageDTO } from 'services/api';
|
||||||
import { Graph } from 'services/api';
|
|
||||||
import {
|
import {
|
||||||
canvasSessionIdChanged,
|
canvasSessionIdChanged,
|
||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
@ -67,112 +66,106 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
|
|
||||||
moduleLog.debug(`Generation mode: ${generationMode}`);
|
moduleLog.debug(`Generation mode: ${generationMode}`);
|
||||||
|
|
||||||
// Build the canvas graph
|
// Temp placeholders for the init and mask images
|
||||||
const graphComponents = await buildCanvasGraphComponents(
|
let canvasInitImage: ImageDTO | undefined;
|
||||||
state,
|
let canvasMaskImage: ImageDTO | undefined;
|
||||||
generationMode
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!graphComponents) {
|
// For img2img and inpaint/outpaint, we need to upload the init images
|
||||||
moduleLog.error('Problem building graph');
|
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
|
||||||
return;
|
// upload the image, saving the request id
|
||||||
}
|
const { requestId: initImageUploadedRequestId } = dispatch(
|
||||||
|
|
||||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
|
||||||
|
|
||||||
// Assemble! Note that this graph *does not have the init or mask image set yet!*
|
|
||||||
const nodes: Graph['nodes'] = {
|
|
||||||
[rangeNode.id]: rangeNode,
|
|
||||||
[iterateNode.id]: iterateNode,
|
|
||||||
[baseNode.id]: baseNode,
|
|
||||||
};
|
|
||||||
|
|
||||||
const graph = { nodes, edges };
|
|
||||||
|
|
||||||
dispatch(canvasGraphBuilt(graph));
|
|
||||||
|
|
||||||
moduleLog.debug({ data: graph }, 'Canvas graph built');
|
|
||||||
|
|
||||||
// If we are generating img2img or inpaint, we need to upload the init images
|
|
||||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
|
||||||
const baseFilename = `${uuidv4()}.png`;
|
|
||||||
dispatch(
|
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
file: new File([baseBlob], 'canvasInitImage.png', {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
imageCategory: 'general',
|
imageCategory: 'general',
|
||||||
isIntermediate: true,
|
isIntermediate: true,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
// Wait for the image to be uploaded
|
// Wait for the image to be uploaded, matching by request id
|
||||||
const [{ payload: baseImageDTO }] = await take(
|
const [{ payload }] = await take(
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
imageUploaded.fulfilled.match(action) &&
|
imageUploaded.fulfilled.match(action) &&
|
||||||
action.meta.arg.formData.file.name === baseFilename
|
action.meta.requestId === initImageUploadedRequestId
|
||||||
);
|
);
|
||||||
|
|
||||||
// Update the base node with the image name and type
|
canvasInitImage = payload;
|
||||||
baseNode.image = {
|
|
||||||
image_name: baseImageDTO.image_name,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For inpaint, we also need to upload the mask layer
|
// For inpaint/outpaint, we also need to upload the mask layer
|
||||||
if (baseNode.type === 'inpaint') {
|
if (['inpaint', 'outpaint'].includes(generationMode)) {
|
||||||
const maskFilename = `${uuidv4()}.png`;
|
// upload the image, saving the request id
|
||||||
dispatch(
|
const { requestId: maskImageUploadedRequestId } = dispatch(
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
imageCategory: 'mask',
|
imageCategory: 'mask',
|
||||||
isIntermediate: true,
|
isIntermediate: true,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
// Wait for the mask to be uploaded
|
// Wait for the image to be uploaded, matching by request id
|
||||||
const [{ payload: maskImageDTO }] = await take(
|
const [{ payload }] = await take(
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
imageUploaded.fulfilled.match(action) &&
|
imageUploaded.fulfilled.match(action) &&
|
||||||
action.meta.arg.formData.file.name === maskFilename
|
action.meta.requestId === maskImageUploadedRequestId
|
||||||
);
|
);
|
||||||
|
|
||||||
// Update the base node with the image name and type
|
canvasMaskImage = payload;
|
||||||
baseNode.mask = {
|
|
||||||
image_name: maskImageDTO.image_name,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the session and wait for response
|
const graph = buildCanvasGraph(
|
||||||
dispatch(sessionCreated({ graph }));
|
state,
|
||||||
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
|
generationMode,
|
||||||
|
canvasInitImage,
|
||||||
|
canvasMaskImage
|
||||||
|
);
|
||||||
|
|
||||||
|
moduleLog.debug({ graph }, `Canvas graph built`);
|
||||||
|
|
||||||
|
// currently this action is just listened to for logging
|
||||||
|
dispatch(canvasGraphBuilt(graph));
|
||||||
|
|
||||||
|
// Create the session, store the request id
|
||||||
|
const { requestId: sessionCreatedRequestId } = dispatch(
|
||||||
|
sessionCreated({ graph })
|
||||||
|
);
|
||||||
|
|
||||||
|
// Take the session created action, matching by its request id
|
||||||
|
const [sessionCreatedAction] = await take(
|
||||||
|
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
||||||
|
sessionCreated.fulfilled.match(action) &&
|
||||||
|
action.meta.requestId === sessionCreatedRequestId
|
||||||
|
);
|
||||||
const sessionId = sessionCreatedAction.payload.id;
|
const sessionId = sessionCreatedAction.payload.id;
|
||||||
|
|
||||||
// Associate the init image with the session, now that we have the session ID
|
// Associate the init image with the session, now that we have the session ID
|
||||||
if (
|
if (['img2img', 'inpaint'].includes(generationMode) && canvasInitImage) {
|
||||||
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
|
|
||||||
baseNode.image
|
|
||||||
) {
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUpdated({
|
imageUpdated({
|
||||||
imageName: baseNode.image.image_name,
|
imageName: canvasInitImage.image_name,
|
||||||
requestBody: { session_id: sessionId },
|
requestBody: { session_id: sessionId },
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Associate the mask image with the session, now that we have the session ID
|
// Associate the mask image with the session, now that we have the session ID
|
||||||
if (baseNode.type === 'inpaint' && baseNode.mask) {
|
if (['inpaint'].includes(generationMode) && canvasMaskImage) {
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUpdated({
|
imageUpdated({
|
||||||
imageName: baseNode.mask.image_name,
|
imageName: canvasMaskImage.image_name,
|
||||||
requestBody: { session_id: sessionId },
|
requestBody: { session_id: sessionId },
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prep the canvas staging area if it is not yet initialized
|
||||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
dispatch(
|
dispatch(
|
||||||
stagingAreaInitialized({
|
stagingAreaInitialized({
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
|
||||||
import { sessionCreated } from 'services/thunks/session';
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ export const addUserInvokedImageToImageListener = () => {
|
|||||||
effect: async (action, { getState, dispatch, take }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildImageToImageGraph(state);
|
const graph = buildLinearImageToImageGraph(state);
|
||||||
dispatch(imageToImageGraphBuilt(graph));
|
dispatch(imageToImageGraphBuilt(graph));
|
||||||
moduleLog.debug({ data: graph }, 'Image to Image graph built');
|
moduleLog.debug({ data: graph }, 'Image to Image graph built');
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
|
||||||
import { sessionCreated } from 'services/thunks/session';
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ export const addUserInvokedTextToImageListener = () => {
|
|||||||
effect: async (action, { getState, dispatch, take }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildTextToImageGraph(state);
|
const graph = buildLinearTextToImageGraph(state);
|
||||||
|
|
||||||
dispatch(textToImageGraphBuilt(graph));
|
dispatch(textToImageGraphBuilt(graph));
|
||||||
|
|
||||||
|
@ -2,8 +2,7 @@ import { RootState } from 'app/store/store';
|
|||||||
import { filter, forEach, size } from 'lodash-es';
|
import { filter, forEach, size } from 'lodash-es';
|
||||||
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
||||||
import { NonNullableGraph } from '../types/types';
|
import { NonNullableGraph } from '../types/types';
|
||||||
|
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
||||||
const CONTROL_NET_COLLECT = 'control_net_collect';
|
|
||||||
|
|
||||||
export const addControlNetToLinearGraph = (
|
export const addControlNetToLinearGraph = (
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
@ -37,7 +36,7 @@ export const addControlNetToLinearGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
forEach(controlNets, (controlNet, index) => {
|
forEach(controlNets, (controlNet) => {
|
||||||
const {
|
const {
|
||||||
controlNetId,
|
controlNetId,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
|
@ -1,116 +1,39 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import { ImageDTO } from 'services/api';
|
||||||
Edge,
|
|
||||||
ImageToImageInvocation,
|
|
||||||
InpaintInvocation,
|
|
||||||
IterateInvocation,
|
|
||||||
RandomRangeInvocation,
|
|
||||||
RangeInvocation,
|
|
||||||
TextToImageInvocation,
|
|
||||||
} from 'services/api';
|
|
||||||
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
|
|
||||||
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
|
|
||||||
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
|
|
||||||
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
|
|
||||||
import { buildEdges } from '../edgeBuilders/buildEdges';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode';
|
import { forEach } from 'lodash-es';
|
||||||
|
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
|
||||||
|
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
const buildBaseNode = (
|
export const buildCanvasGraph = (
|
||||||
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
|
||||||
state: RootState
|
|
||||||
):
|
|
||||||
| TextToImageInvocation
|
|
||||||
| ImageToImageInvocation
|
|
||||||
| InpaintInvocation
|
|
||||||
| undefined => {
|
|
||||||
const overrides = {
|
|
||||||
...state.canvas.boundingBoxDimensions,
|
|
||||||
is_intermediate: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (nodeType === 'txt2img') {
|
|
||||||
return buildTxt2ImgNode(state, overrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nodeType === 'img2img') {
|
|
||||||
return buildImg2ImgNode(state, overrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
|
||||||
return buildInpaintNode(state, overrides);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds the Canvas workflow graph and image blobs.
|
|
||||||
*/
|
|
||||||
export const buildCanvasGraphComponents = async (
|
|
||||||
state: RootState,
|
state: RootState,
|
||||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'
|
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||||
): Promise<
|
canvasInitImage: ImageDTO | undefined,
|
||||||
| {
|
canvasMaskImage: ImageDTO | undefined
|
||||||
rangeNode: RangeInvocation | RandomRangeInvocation;
|
) => {
|
||||||
iterateNode: IterateInvocation;
|
let graph: NonNullableGraph;
|
||||||
baseNode:
|
|
||||||
| TextToImageInvocation
|
|
||||||
| ImageToImageInvocation
|
|
||||||
| InpaintInvocation;
|
|
||||||
edges: Edge[];
|
|
||||||
}
|
|
||||||
| undefined
|
|
||||||
> => {
|
|
||||||
// The base node is a txt2img, img2img or inpaint node
|
|
||||||
const baseNode = buildBaseNode(generationMode, state);
|
|
||||||
|
|
||||||
if (!baseNode) {
|
if (generationMode === 'txt2img') {
|
||||||
moduleLog.error('Problem building base node');
|
graph = buildCanvasTextToImageGraph(state);
|
||||||
return;
|
} else if (generationMode === 'img2img') {
|
||||||
|
if (!canvasInitImage) {
|
||||||
|
throw new Error('Missing canvas init image');
|
||||||
|
}
|
||||||
|
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||||
|
} else {
|
||||||
|
if (!canvasInitImage || !canvasMaskImage) {
|
||||||
|
throw new Error('Missing canvas init and mask images');
|
||||||
|
}
|
||||||
|
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (baseNode.type === 'inpaint') {
|
forEach(graph.nodes, (node) => {
|
||||||
const {
|
graph.nodes[node.id].is_intermediate = true;
|
||||||
seamSize,
|
});
|
||||||
seamBlur,
|
|
||||||
seamSteps,
|
|
||||||
seamStrength,
|
|
||||||
tileSize,
|
|
||||||
infillMethod,
|
|
||||||
} = state.generation;
|
|
||||||
|
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } =
|
return graph;
|
||||||
state.canvas;
|
|
||||||
|
|
||||||
if (boundingBoxScaleMethod !== 'none') {
|
|
||||||
baseNode.inpaint_width = scaledBoundingBoxDimensions.width;
|
|
||||||
baseNode.inpaint_height = scaledBoundingBoxDimensions.height;
|
|
||||||
}
|
|
||||||
|
|
||||||
baseNode.seam_size = seamSize;
|
|
||||||
baseNode.seam_blur = seamBlur;
|
|
||||||
baseNode.seam_strength = seamStrength;
|
|
||||||
baseNode.seam_steps = seamSteps;
|
|
||||||
baseNode.infill_method = infillMethod as InpaintInvocation['infill_method'];
|
|
||||||
|
|
||||||
if (infillMethod === 'tile') {
|
|
||||||
baseNode.tile_size = tileSize;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We always range and iterate nodes, no matter the iteration count
|
|
||||||
// This is required to provide the correct seeds to the backend engine
|
|
||||||
const rangeNode = buildRangeNode(state);
|
|
||||||
const iterateNode = buildIterateNode();
|
|
||||||
|
|
||||||
// Build the edges for the nodes selected.
|
|
||||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
|
||||||
|
|
||||||
return {
|
|
||||||
rangeNode,
|
|
||||||
iterateNode,
|
|
||||||
baseNode,
|
|
||||||
edges,
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,331 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
ImageDTO,
|
||||||
|
ImageResizeInvocation,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
ITERATE,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RANDOM_INT,
|
||||||
|
RANGE_OF_SIZE,
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
LATENTS_TO_LATENTS,
|
||||||
|
RESIZE,
|
||||||
|
} from './constants';
|
||||||
|
import { set } from 'lodash-es';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the Canvas tab's Image to Image graph.
|
||||||
|
*/
|
||||||
|
export const buildCanvasImageToImageGraph = (
|
||||||
|
state: RootState,
|
||||||
|
initialImage: ImageDTO
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model: model_name,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
img2imgStrength: strength,
|
||||||
|
iterations,
|
||||||
|
seed,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
// The bounding box determines width and height, not the width and height params
|
||||||
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
},
|
||||||
|
[RANGE_OF_SIZE]: {
|
||||||
|
type: 'range_of_size',
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
// seed - must be connected manually
|
||||||
|
// start: 0,
|
||||||
|
size: iterations,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
},
|
||||||
|
[MODEL_LOADER]: {
|
||||||
|
type: 'sd1_model_loader',
|
||||||
|
id: MODEL_LOADER,
|
||||||
|
model_name,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
},
|
||||||
|
[ITERATE]: {
|
||||||
|
type: 'iterate',
|
||||||
|
id: ITERATE,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_LATENTS]: {
|
||||||
|
type: 'l2l',
|
||||||
|
id: LATENTS_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
strength,
|
||||||
|
},
|
||||||
|
[IMAGE_TO_LATENTS]: {
|
||||||
|
type: 'i2l',
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||||
|
// image: {
|
||||||
|
// image_name: initialImage.image_name,
|
||||||
|
// },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: RANGE_OF_SIZE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle seed
|
||||||
|
if (shouldRandomizeSeed) {
|
||||||
|
// Random int node to generate the starting seed
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
|
||||||
|
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// User specified seed, so set the start of the range of size to the seed
|
||||||
|
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle `fit`
|
||||||
|
if (initialImage.width !== width || initialImage.height !== height) {
|
||||||
|
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||||
|
|
||||||
|
// Create a resize node, explicitly setting its image
|
||||||
|
const resizeNode: ImageResizeInvocation = {
|
||||||
|
id: RESIZE,
|
||||||
|
type: 'img_resize',
|
||||||
|
image: {
|
||||||
|
image_name: initialImage.image_name,
|
||||||
|
},
|
||||||
|
is_intermediate: true,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RESIZE] = resizeNode;
|
||||||
|
|
||||||
|
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'image' },
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
|
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||||
|
image_name: initialImage.image_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// add controlnet
|
||||||
|
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -0,0 +1,224 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
ImageDTO,
|
||||||
|
InpaintInvocation,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
ITERATE,
|
||||||
|
MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RANDOM_INT,
|
||||||
|
RANGE_OF_SIZE,
|
||||||
|
INPAINT_GRAPH,
|
||||||
|
INPAINT,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the Canvas tab's Inpaint graph.
|
||||||
|
*/
|
||||||
|
export const buildCanvasInpaintGraph = (
|
||||||
|
state: RootState,
|
||||||
|
canvasInitImage: ImageDTO,
|
||||||
|
canvasMaskImage: ImageDTO
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model: model_name,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
img2imgStrength: strength,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
iterations,
|
||||||
|
seed,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
seamSize,
|
||||||
|
seamBlur,
|
||||||
|
seamSteps,
|
||||||
|
seamStrength,
|
||||||
|
tileSize,
|
||||||
|
infillMethod,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
// The bounding box determines width and height, not the width and height params
|
||||||
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
// We may need to set the inpaint width and height to scale the image
|
||||||
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: INPAINT_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[INPAINT]: {
|
||||||
|
type: 'inpaint',
|
||||||
|
id: INPAINT,
|
||||||
|
steps,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
image: {
|
||||||
|
image_name: canvasInitImage.image_name,
|
||||||
|
},
|
||||||
|
strength,
|
||||||
|
fit: shouldFitToWidthHeight,
|
||||||
|
mask: {
|
||||||
|
image_name: canvasMaskImage.image_name,
|
||||||
|
},
|
||||||
|
seam_size: seamSize,
|
||||||
|
seam_blur: seamBlur,
|
||||||
|
seam_strength: seamStrength,
|
||||||
|
seam_steps: seamSteps,
|
||||||
|
tile_size: infillMethod === 'tile' ? tileSize : undefined,
|
||||||
|
infill_method: infillMethod as InpaintInvocation['infill_method'],
|
||||||
|
inpaint_width:
|
||||||
|
boundingBoxScaleMethod !== 'none'
|
||||||
|
? scaledBoundingBoxDimensions.width
|
||||||
|
: undefined,
|
||||||
|
inpaint_height:
|
||||||
|
boundingBoxScaleMethod !== 'none'
|
||||||
|
? scaledBoundingBoxDimensions.height
|
||||||
|
: undefined,
|
||||||
|
},
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
},
|
||||||
|
[MODEL_LOADER]: {
|
||||||
|
type: 'sd1_model_loader',
|
||||||
|
id: MODEL_LOADER,
|
||||||
|
model_name,
|
||||||
|
},
|
||||||
|
[RANGE_OF_SIZE]: {
|
||||||
|
type: 'range_of_size',
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
// seed - must be connected manually
|
||||||
|
// start: 0,
|
||||||
|
size: iterations,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
[ITERATE]: {
|
||||||
|
type: 'iterate',
|
||||||
|
id: ITERATE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: RANGE_OF_SIZE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: INPAINT,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle seed
|
||||||
|
if (shouldRandomizeSeed) {
|
||||||
|
// Random int node to generate the starting seed
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
|
||||||
|
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// User specified seed, so set the start of the range of size to the seed
|
||||||
|
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -0,0 +1,224 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
||||||
|
import {
|
||||||
|
ITERATE,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RANDOM_INT,
|
||||||
|
RANGE_OF_SIZE,
|
||||||
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
|
TEXT_TO_LATENTS,
|
||||||
|
} from './constants';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
|
*/
|
||||||
|
export const buildCanvasTextToImageGraph = (
|
||||||
|
state: RootState
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model: model_name,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
iterations,
|
||||||
|
seed,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
// The bounding box determines width and height, not the width and height params
|
||||||
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: TEXT_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
},
|
||||||
|
[RANGE_OF_SIZE]: {
|
||||||
|
type: 'range_of_size',
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
// start: 0, // seed - must be connected manually
|
||||||
|
size: iterations,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
},
|
||||||
|
[TEXT_TO_LATENTS]: {
|
||||||
|
type: 't2l',
|
||||||
|
id: TEXT_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
},
|
||||||
|
[MODEL_LOADER]: {
|
||||||
|
type: 'sd1_model_loader',
|
||||||
|
id: MODEL_LOADER,
|
||||||
|
model_name,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
},
|
||||||
|
[ITERATE]: {
|
||||||
|
type: 'iterate',
|
||||||
|
id: ITERATE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: RANGE_OF_SIZE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle seed
|
||||||
|
if (shouldRandomizeSeed) {
|
||||||
|
// Random int node to generate the starting seed
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
|
||||||
|
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// User specified seed, so set the start of the range of size to the seed
|
||||||
|
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add controlnet
|
||||||
|
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -1,465 +0,0 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
|
||||||
import { set } from 'lodash-es';
|
|
||||||
import {
|
|
||||||
CompelInvocation,
|
|
||||||
Graph,
|
|
||||||
ImageResizeInvocation,
|
|
||||||
ImageToLatentsInvocation,
|
|
||||||
IterateInvocation,
|
|
||||||
LatentsToImageInvocation,
|
|
||||||
LatentsToLatentsInvocation,
|
|
||||||
NoiseInvocation,
|
|
||||||
RandomIntInvocation,
|
|
||||||
RangeOfSizeInvocation,
|
|
||||||
SD1ModelLoaderInvocation,
|
|
||||||
SD2ModelLoaderInvocation,
|
|
||||||
} from 'services/api';
|
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
|
||||||
|
|
||||||
const MODEL_LOADER = 'model_loader';
|
|
||||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
|
||||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
|
||||||
const IMAGE_TO_LATENTS = 'image_to_latents';
|
|
||||||
const LATENTS_TO_LATENTS = 'latents_to_latents';
|
|
||||||
const LATENTS_TO_IMAGE = 'latents_to_image';
|
|
||||||
const RESIZE = 'resize_image';
|
|
||||||
const NOISE = 'noise';
|
|
||||||
const RANDOM_INT = 'rand_int';
|
|
||||||
const RANGE_OF_SIZE = 'range_of_size';
|
|
||||||
const ITERATE = 'iterate';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds the Image to Image tab graph.
|
|
||||||
*/
|
|
||||||
export const buildImageToImageGraph = (state: RootState): Graph => {
|
|
||||||
const {
|
|
||||||
positivePrompt,
|
|
||||||
negativePrompt,
|
|
||||||
model,
|
|
||||||
currentModelType,
|
|
||||||
cfgScale: cfg_scale,
|
|
||||||
scheduler,
|
|
||||||
steps,
|
|
||||||
initialImage,
|
|
||||||
img2imgStrength: strength,
|
|
||||||
shouldFitToWidthHeight,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
iterations,
|
|
||||||
seed,
|
|
||||||
shouldRandomizeSeed,
|
|
||||||
} = state.generation;
|
|
||||||
|
|
||||||
if (!initialImage) {
|
|
||||||
moduleLog.error('No initial image found in state');
|
|
||||||
throw new Error('No initial image found in state');
|
|
||||||
}
|
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
|
||||||
nodes: {},
|
|
||||||
edges: [],
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the model loader node
|
|
||||||
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
|
||||||
id: MODEL_LOADER,
|
|
||||||
type: currentModelType,
|
|
||||||
model_name: model,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the positive conditioning (prompt) node
|
|
||||||
const positiveConditioningNode: CompelInvocation = {
|
|
||||||
id: POSITIVE_CONDITIONING,
|
|
||||||
type: 'compel',
|
|
||||||
prompt: positivePrompt,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Negative conditioning
|
|
||||||
const negativeConditioningNode: CompelInvocation = {
|
|
||||||
id: NEGATIVE_CONDITIONING,
|
|
||||||
type: 'compel',
|
|
||||||
prompt: negativePrompt,
|
|
||||||
};
|
|
||||||
|
|
||||||
// This will encode the raster image to latents - but it may get its `image` from a resize node,
|
|
||||||
// so we do not set its `image` property yet
|
|
||||||
const imageToLatentsNode: ImageToLatentsInvocation = {
|
|
||||||
id: IMAGE_TO_LATENTS,
|
|
||||||
type: 'i2l',
|
|
||||||
};
|
|
||||||
|
|
||||||
// This does the actual img2img inference
|
|
||||||
const latentsToLatentsNode: LatentsToLatentsInvocation = {
|
|
||||||
id: LATENTS_TO_LATENTS,
|
|
||||||
type: 'l2l',
|
|
||||||
cfg_scale,
|
|
||||||
scheduler,
|
|
||||||
steps,
|
|
||||||
strength,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Finally we decode the latents back to an image
|
|
||||||
const latentsToImageNode: LatentsToImageInvocation = {
|
|
||||||
id: LATENTS_TO_IMAGE,
|
|
||||||
type: 'l2i',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Add all those nodes to the graph
|
|
||||||
graph.nodes[MODEL_LOADER] = modelLoaderNode;
|
|
||||||
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
|
||||||
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
|
||||||
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
|
|
||||||
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
|
|
||||||
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
|
||||||
|
|
||||||
// Connect the model loader to the required nodes
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'clip' },
|
|
||||||
destination: {
|
|
||||||
node_id: POSITIVE_CONDITIONING,
|
|
||||||
field: 'clip',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'clip' },
|
|
||||||
destination: {
|
|
||||||
node_id: NEGATIVE_CONDITIONING,
|
|
||||||
field: 'clip',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'vae' },
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'unet' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'unet',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'vae' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect the prompt nodes to the imageToLatents node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'positive_conditioning',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'negative_conditioning',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect the image-encoding node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'latents',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect the image-decoding node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'latents',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Now we need to handle iterations and random seeds. There are four possible scenarios:
|
|
||||||
* - Single iteration, explicit seed
|
|
||||||
* - Single iteration, random seed
|
|
||||||
* - Multiple iterations, explicit seed
|
|
||||||
* - Multiple iterations, random seed
|
|
||||||
*
|
|
||||||
* They all have different graphs and connections.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Single iteration, explicit seed
|
|
||||||
if (!shouldRandomizeSeed && iterations === 1) {
|
|
||||||
// Noise node using the explicit seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
seed: seed,
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect noise to l2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single iteration, random seed
|
|
||||||
if (shouldRandomizeSeed && iterations === 1) {
|
|
||||||
// Random int node to generate the seed
|
|
||||||
const randomIntNode: RandomIntInvocation = {
|
|
||||||
id: RANDOM_INT,
|
|
||||||
type: 'rand_int',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect random int to the seed of the noise node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANDOM_INT, field: 'a' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to l2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple iterations, explicit seed
|
|
||||||
if (!shouldRandomizeSeed && iterations > 1) {
|
|
||||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
|
||||||
// of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
|
|
||||||
// iterations.
|
|
||||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
|
||||||
id: RANGE_OF_SIZE,
|
|
||||||
type: 'range_of_size',
|
|
||||||
start: seed,
|
|
||||||
size: iterations,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Iterate node to iterate over the seeds generated by the range of size node
|
|
||||||
const iterateNode: IterateInvocation = {
|
|
||||||
id: ITERATE,
|
|
||||||
type: 'iterate',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Adding to the graph
|
|
||||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
|
||||||
graph.nodes[ITERATE] = iterateNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect range of size to iterate
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
|
||||||
destination: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'collection',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect iterate to noise
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'item',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to l2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple iterations, random seed
|
|
||||||
if (shouldRandomizeSeed && iterations > 1) {
|
|
||||||
// Random int node to generate the seed
|
|
||||||
const randomIntNode: RandomIntInvocation = {
|
|
||||||
id: RANDOM_INT,
|
|
||||||
type: 'rand_int',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
|
||||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
|
||||||
id: RANGE_OF_SIZE,
|
|
||||||
type: 'range_of_size',
|
|
||||||
size: iterations,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Iterate node to iterate over the seeds generated by the range of size node
|
|
||||||
const iterateNode: IterateInvocation = {
|
|
||||||
id: ITERATE,
|
|
||||||
type: 'iterate',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Adding to the graph
|
|
||||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
|
||||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
|
||||||
graph.nodes[ITERATE] = iterateNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANDOM_INT, field: 'a' },
|
|
||||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect range of size to iterate
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
|
||||||
destination: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'collection',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect iterate to noise
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'item',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to l2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
shouldFitToWidthHeight &&
|
|
||||||
(initialImage.width !== width || initialImage.height !== height)
|
|
||||||
) {
|
|
||||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
|
||||||
|
|
||||||
// Create a resize node, explicitly setting its image
|
|
||||||
const resizeNode: ImageResizeInvocation = {
|
|
||||||
id: RESIZE,
|
|
||||||
type: 'img_resize',
|
|
||||||
image: {
|
|
||||||
image_name: initialImage.image_name,
|
|
||||||
},
|
|
||||||
is_intermediate: true,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[RESIZE] = resizeNode;
|
|
||||||
|
|
||||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RESIZE, field: 'image' },
|
|
||||||
destination: {
|
|
||||||
node_id: IMAGE_TO_LATENTS,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RESIZE, field: 'width' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'width',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RESIZE, field: 'height' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'height',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
|
||||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
|
||||||
image_name: initialImage.image_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Pass the image's dimensions to the `NOISE` node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'width',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'height',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
|
||||||
|
|
||||||
return graph;
|
|
||||||
};
|
|
@ -0,0 +1,338 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
ImageResizeInvocation,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
ITERATE,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RANDOM_INT,
|
||||||
|
RANGE_OF_SIZE,
|
||||||
|
IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
LATENTS_TO_LATENTS,
|
||||||
|
RESIZE,
|
||||||
|
} from './constants';
|
||||||
|
import { set } from 'lodash-es';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the Image to Image tab graph.
|
||||||
|
*/
|
||||||
|
export const buildLinearImageToImageGraph = (
|
||||||
|
state: RootState
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model: model_name,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
initialImage,
|
||||||
|
img2imgStrength: strength,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
iterations,
|
||||||
|
seed,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (!initialImage) {
|
||||||
|
moduleLog.error('No initial image found in state');
|
||||||
|
throw new Error('No initial image found in state');
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
},
|
||||||
|
[RANGE_OF_SIZE]: {
|
||||||
|
type: 'range_of_size',
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
// seed - must be connected manually
|
||||||
|
// start: 0,
|
||||||
|
size: iterations,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
},
|
||||||
|
[MODEL_LOADER]: {
|
||||||
|
type: 'sd1_model_loader',
|
||||||
|
id: MODEL_LOADER,
|
||||||
|
model_name,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
},
|
||||||
|
[ITERATE]: {
|
||||||
|
type: 'iterate',
|
||||||
|
id: ITERATE,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_LATENTS]: {
|
||||||
|
type: 'l2l',
|
||||||
|
id: LATENTS_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
strength,
|
||||||
|
},
|
||||||
|
[IMAGE_TO_LATENTS]: {
|
||||||
|
type: 'i2l',
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||||
|
// image: {
|
||||||
|
// image_name: initialImage.image_name,
|
||||||
|
// },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: RANGE_OF_SIZE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle seed
|
||||||
|
if (shouldRandomizeSeed) {
|
||||||
|
// Random int node to generate the starting seed
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
|
||||||
|
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// User specified seed, so set the start of the range of size to the seed
|
||||||
|
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle `fit`
|
||||||
|
if (
|
||||||
|
shouldFitToWidthHeight &&
|
||||||
|
(initialImage.width !== width || initialImage.height !== height)
|
||||||
|
) {
|
||||||
|
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||||
|
|
||||||
|
// Create a resize node, explicitly setting its image
|
||||||
|
const resizeNode: ImageResizeInvocation = {
|
||||||
|
id: RESIZE,
|
||||||
|
type: 'img_resize',
|
||||||
|
image: {
|
||||||
|
image_name: initialImage.image_name,
|
||||||
|
},
|
||||||
|
is_intermediate: true,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RESIZE] = resizeNode;
|
||||||
|
|
||||||
|
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'image' },
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
|
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||||
|
image_name: initialImage.image_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// add controlnet
|
||||||
|
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -0,0 +1,226 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
||||||
|
import {
|
||||||
|
ITERATE,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
MODEL_LOADER,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RANDOM_INT,
|
||||||
|
RANGE_OF_SIZE,
|
||||||
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
|
TEXT_TO_LATENTS,
|
||||||
|
} from './constants';
|
||||||
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
|
||||||
|
type TextToImageGraphOverrides = {
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const buildLinearTextToImageGraph = (
|
||||||
|
state: RootState,
|
||||||
|
overrides?: TextToImageGraphOverrides
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model: model_name,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
iterations,
|
||||||
|
seed,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: TEXT_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
},
|
||||||
|
[RANGE_OF_SIZE]: {
|
||||||
|
type: 'range_of_size',
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
// start: 0, // seed - must be connected manually
|
||||||
|
size: iterations,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
width: overrides?.width || width,
|
||||||
|
height: overrides?.height || height,
|
||||||
|
},
|
||||||
|
[TEXT_TO_LATENTS]: {
|
||||||
|
type: 't2l',
|
||||||
|
id: TEXT_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
},
|
||||||
|
[MODEL_LOADER]: {
|
||||||
|
type: 'sd1_model_loader',
|
||||||
|
id: MODEL_LOADER,
|
||||||
|
model_name,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
},
|
||||||
|
[ITERATE]: {
|
||||||
|
type: 'iterate',
|
||||||
|
id: ITERATE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: RANGE_OF_SIZE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle seed
|
||||||
|
if (shouldRandomizeSeed) {
|
||||||
|
// Random int node to generate the starting seed
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
|
||||||
|
// Connect random int to the start of the range of size so the range starts on the random first seed
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// User specified seed, so set the start of the range of size to the seed
|
||||||
|
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add controlnet
|
||||||
|
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -1,357 +0,0 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
|
||||||
import {
|
|
||||||
CompelInvocation,
|
|
||||||
Graph,
|
|
||||||
IterateInvocation,
|
|
||||||
LatentsToImageInvocation,
|
|
||||||
NoiseInvocation,
|
|
||||||
RandomIntInvocation,
|
|
||||||
RangeOfSizeInvocation,
|
|
||||||
SD1ModelLoaderInvocation,
|
|
||||||
SD2ModelLoaderInvocation,
|
|
||||||
TextToLatentsInvocation,
|
|
||||||
} from 'services/api';
|
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
|
||||||
|
|
||||||
const MODEL_LOADER = 'model_loader';
|
|
||||||
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
|
||||||
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
|
||||||
const TEXT_TO_LATENTS = 'text_to_latents';
|
|
||||||
const LATENTS_TO_IMAGE = 'latents_to_image';
|
|
||||||
const NOISE = 'noise';
|
|
||||||
const RANDOM_INT = 'rand_int';
|
|
||||||
const RANGE_OF_SIZE = 'range_of_size';
|
|
||||||
const ITERATE = 'iterate';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds the Text to Image tab graph.
|
|
||||||
*/
|
|
||||||
export const buildTextToImageGraph = (state: RootState): Graph => {
|
|
||||||
const {
|
|
||||||
model,
|
|
||||||
currentModelType,
|
|
||||||
positivePrompt,
|
|
||||||
negativePrompt,
|
|
||||||
cfgScale: cfg_scale,
|
|
||||||
scheduler,
|
|
||||||
steps,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
iterations,
|
|
||||||
seed,
|
|
||||||
shouldRandomizeSeed,
|
|
||||||
} = state.generation;
|
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
|
||||||
nodes: {},
|
|
||||||
edges: [],
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the model loader node
|
|
||||||
const modelLoaderNode: SD1ModelLoaderInvocation | SD2ModelLoaderInvocation = {
|
|
||||||
id: MODEL_LOADER,
|
|
||||||
type: currentModelType,
|
|
||||||
model_name: model,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create the conditioning, t2l and l2i nodes
|
|
||||||
const positiveConditioningNode: CompelInvocation = {
|
|
||||||
id: POSITIVE_CONDITIONING,
|
|
||||||
type: 'compel',
|
|
||||||
prompt: positivePrompt,
|
|
||||||
};
|
|
||||||
|
|
||||||
const negativeConditioningNode: CompelInvocation = {
|
|
||||||
id: NEGATIVE_CONDITIONING,
|
|
||||||
type: 'compel',
|
|
||||||
prompt: negativePrompt,
|
|
||||||
};
|
|
||||||
|
|
||||||
const textToLatentsNode: TextToLatentsInvocation = {
|
|
||||||
id: TEXT_TO_LATENTS,
|
|
||||||
type: 't2l',
|
|
||||||
cfg_scale,
|
|
||||||
scheduler,
|
|
||||||
steps,
|
|
||||||
};
|
|
||||||
|
|
||||||
const latentsToImageNode: LatentsToImageInvocation = {
|
|
||||||
id: LATENTS_TO_IMAGE,
|
|
||||||
type: 'l2i',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Add to the graph
|
|
||||||
graph.nodes[MODEL_LOADER] = modelLoaderNode;
|
|
||||||
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
|
||||||
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
|
||||||
graph.nodes[TEXT_TO_LATENTS] = textToLatentsNode;
|
|
||||||
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
|
||||||
|
|
||||||
// Connect the model loader to the required nodes
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'clip' },
|
|
||||||
destination: {
|
|
||||||
node_id: POSITIVE_CONDITIONING,
|
|
||||||
field: 'clip',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'clip' },
|
|
||||||
destination: {
|
|
||||||
node_id: NEGATIVE_CONDITIONING,
|
|
||||||
field: 'clip',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'unet' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'unet',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: MODEL_LOADER, field: 'vae' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'vae',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect them
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'positive_conditioning',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'negative_conditioning',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: TEXT_TO_LATENTS, field: 'latents' },
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'latents',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Now we need to handle iterations and random seeds. There are four possible scenarios:
|
|
||||||
* - Single iteration, explicit seed
|
|
||||||
* - Single iteration, random seed
|
|
||||||
* - Multiple iterations, explicit seed
|
|
||||||
* - Multiple iterations, random seed
|
|
||||||
*
|
|
||||||
* They all have different graphs and connections.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Single iteration, explicit seed
|
|
||||||
if (!shouldRandomizeSeed && iterations === 1) {
|
|
||||||
// Noise node using the explicit seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
seed: seed,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect noise to l2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single iteration, random seed
|
|
||||||
if (shouldRandomizeSeed && iterations === 1) {
|
|
||||||
// Random int node to generate the seed
|
|
||||||
const randomIntNode: RandomIntInvocation = {
|
|
||||||
id: RANDOM_INT,
|
|
||||||
type: 'rand_int',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect random int to the seed of the noise node
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANDOM_INT, field: 'a' },
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to t2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple iterations, explicit seed
|
|
||||||
if (!shouldRandomizeSeed && iterations > 1) {
|
|
||||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
|
||||||
// of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
|
|
||||||
// iterations.
|
|
||||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
|
||||||
id: RANGE_OF_SIZE,
|
|
||||||
type: 'range_of_size',
|
|
||||||
start: seed,
|
|
||||||
size: iterations,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Iterate node to iterate over the seeds generated by the range of size node
|
|
||||||
const iterateNode: IterateInvocation = {
|
|
||||||
id: ITERATE,
|
|
||||||
type: 'iterate',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Adding to the graph
|
|
||||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
|
||||||
graph.nodes[ITERATE] = iterateNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect range of size to iterate
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
|
||||||
destination: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'collection',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect iterate to noise
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'item',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to t2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple iterations, random seed
|
|
||||||
if (shouldRandomizeSeed && iterations > 1) {
|
|
||||||
// Random int node to generate the seed
|
|
||||||
const randomIntNode: RandomIntInvocation = {
|
|
||||||
id: RANDOM_INT,
|
|
||||||
type: 'rand_int',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Range of size node to generate `iterations` count of seeds - range of size generates a collection
|
|
||||||
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
|
||||||
id: RANGE_OF_SIZE,
|
|
||||||
type: 'range_of_size',
|
|
||||||
size: iterations,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Iterate node to iterate over the seeds generated by the range of size node
|
|
||||||
const iterateNode: IterateInvocation = {
|
|
||||||
id: ITERATE,
|
|
||||||
type: 'iterate',
|
|
||||||
};
|
|
||||||
|
|
||||||
// Noise node without any seed
|
|
||||||
const noiseNode: NoiseInvocation = {
|
|
||||||
id: NOISE,
|
|
||||||
type: 'noise',
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Adding to the graph
|
|
||||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
|
||||||
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
|
||||||
graph.nodes[ITERATE] = iterateNode;
|
|
||||||
graph.nodes[NOISE] = noiseNode;
|
|
||||||
|
|
||||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANDOM_INT, field: 'a' },
|
|
||||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect range of size to iterate
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
|
||||||
destination: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'collection',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect iterate to noise
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'item',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Connect noise to t2l
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: NOISE, field: 'noise' },
|
|
||||||
destination: {
|
|
||||||
node_id: TEXT_TO_LATENTS,
|
|
||||||
field: 'noise',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
|
||||||
|
|
||||||
return graph;
|
|
||||||
};
|
|
@ -0,0 +1,20 @@
|
|||||||
|
// friendly node ids
|
||||||
|
export const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||||
|
export const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||||
|
export const TEXT_TO_LATENTS = 'text_to_latents';
|
||||||
|
export const LATENTS_TO_IMAGE = 'latents_to_image';
|
||||||
|
export const NOISE = 'noise';
|
||||||
|
export const RANDOM_INT = 'rand_int';
|
||||||
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
|
export const ITERATE = 'iterate';
|
||||||
|
export const MODEL_LOADER = 'model_loader';
|
||||||
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
|
export const RESIZE = 'resize_image';
|
||||||
|
export const INPAINT = 'inpaint';
|
||||||
|
export const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||||
|
|
||||||
|
// friendly graph ids
|
||||||
|
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||||
|
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||||
|
export const INPAINT_GRAPH = 'inpaint_graph';
|
@ -1,5 +1,4 @@
|
|||||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||||
import ParamSeedCollapse from 'features/parameters/components/Parameters/Seed/ParamSeedCollapse';
|
|
||||||
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||||
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
|
||||||
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
|
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
|
||||||
@ -8,6 +7,7 @@ import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
|
|||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||||
|
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
|
||||||
|
|
||||||
const UnifiedCanvasParameters = () => {
|
const UnifiedCanvasParameters = () => {
|
||||||
return (
|
return (
|
||||||
@ -16,6 +16,7 @@ const UnifiedCanvasParameters = () => {
|
|||||||
<ParamNegativeConditioning />
|
<ParamNegativeConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<UnifiedCanvasCoreParameters />
|
<UnifiedCanvasCoreParameters />
|
||||||
|
<ParamControlNetCollapse />
|
||||||
<ParamVariationCollapse />
|
<ParamVariationCollapse />
|
||||||
<ParamSymmetryCollapse />
|
<ParamSymmetryCollapse />
|
||||||
<ParamSeamCorrectionCollapse />
|
<ParamSeamCorrectionCollapse />
|
||||||
|
@ -21,7 +21,6 @@ export type { ConditioningField } from './models/ConditioningField';
|
|||||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||||
export type { ControlField } from './models/ControlField';
|
export type { ControlField } from './models/ControlField';
|
||||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||||
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
|
|
||||||
export type { ControlOutput } from './models/ControlOutput';
|
export type { ControlOutput } from './models/ControlOutput';
|
||||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||||
@ -56,7 +55,6 @@ export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation
|
|||||||
export type { ImageRecordChanges } from './models/ImageRecordChanges';
|
export type { ImageRecordChanges } from './models/ImageRecordChanges';
|
||||||
export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
|
export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
|
||||||
export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
|
export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
|
||||||
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
|
||||||
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
||||||
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
||||||
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
||||||
@ -65,6 +63,14 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
|||||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||||
export type { IntOutput } from './models/IntOutput';
|
export type { IntOutput } from './models/IntOutput';
|
||||||
|
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
||||||
|
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
||||||
|
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
||||||
|
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
||||||
|
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
||||||
|
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
||||||
|
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
||||||
|
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
|
||||||
export type { IterateInvocation } from './models/IterateInvocation';
|
export type { IterateInvocation } from './models/IterateInvocation';
|
||||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||||
export type { LatentsField } from './models/LatentsField';
|
export type { LatentsField } from './models/LatentsField';
|
||||||
@ -77,7 +83,6 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
|||||||
export type { LoraInfo } from './models/LoraInfo';
|
export type { LoraInfo } from './models/LoraInfo';
|
||||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||||
export type { LoRAModelConfig } from './models/LoRAModelConfig';
|
|
||||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||||
export type { MaskOutput } from './models/MaskOutput';
|
export type { MaskOutput } from './models/MaskOutput';
|
||||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||||
@ -113,20 +118,13 @@ export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
|||||||
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
||||||
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
||||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||||
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
|
|
||||||
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
|
|
||||||
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
|
|
||||||
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
|
|
||||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||||
export type { SubModelType } from './models/SubModelType';
|
export type { SubModelType } from './models/SubModelType';
|
||||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||||
export type { TextToImageInvocation } from './models/TextToImageInvocation';
|
|
||||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||||
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
|
|
||||||
export type { UNetField } from './models/UNetField';
|
export type { UNetField } from './models/UNetField';
|
||||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||||
export type { VaeField } from './models/VaeField';
|
export type { VaeField } from './models/VaeField';
|
||||||
export type { VaeModelConfig } from './models/VaeModelConfig';
|
|
||||||
export type { VaeRepo } from './models/VaeRepo';
|
export type { VaeRepo } from './models/VaeRepo';
|
||||||
export type { ValidationError } from './models/ValidationError';
|
export type { ValidationError } from './models/ValidationError';
|
||||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||||
|
@ -19,3 +19,4 @@ export type ClipField = {
|
|||||||
*/
|
*/
|
||||||
loras: Array<LoraInfo>;
|
loras: Array<LoraInfo>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -26,7 +26,6 @@ import type { ImagePasteInvocation } from './ImagePasteInvocation';
|
|||||||
import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
|
import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
|
||||||
import type { ImageResizeInvocation } from './ImageResizeInvocation';
|
import type { ImageResizeInvocation } from './ImageResizeInvocation';
|
||||||
import type { ImageScaleInvocation } from './ImageScaleInvocation';
|
import type { ImageScaleInvocation } from './ImageScaleInvocation';
|
||||||
import type { ImageToImageInvocation } from './ImageToImageInvocation';
|
|
||||||
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
|
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
|
||||||
import type { InfillColorInvocation } from './InfillColorInvocation';
|
import type { InfillColorInvocation } from './InfillColorInvocation';
|
||||||
import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
|
import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
|
||||||
@ -62,7 +61,6 @@ import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation';
|
|||||||
import type { ShowImageInvocation } from './ShowImageInvocation';
|
import type { ShowImageInvocation } from './ShowImageInvocation';
|
||||||
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
|
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
|
||||||
import type { SubtractInvocation } from './SubtractInvocation';
|
import type { SubtractInvocation } from './SubtractInvocation';
|
||||||
import type { TextToImageInvocation } from './TextToImageInvocation';
|
|
||||||
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
|
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
|
||||||
import type { UpscaleInvocation } from './UpscaleInvocation';
|
import type { UpscaleInvocation } from './UpscaleInvocation';
|
||||||
import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
|
import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
|
||||||
@ -75,9 +73,10 @@ export type Graph = {
|
|||||||
/**
|
/**
|
||||||
* The nodes in this graph
|
* The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: Record<string, (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | DynamicPromptInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation)>;
|
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation)>;
|
||||||
/**
|
/**
|
||||||
* The connections between nodes and their fields in this graph
|
* The connections between nodes and their fields in this graph
|
||||||
*/
|
*/
|
||||||
edges?: Array<Edge>;
|
edges?: Array<Edge>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ export type GraphExecutionState = {
|
|||||||
/**
|
/**
|
||||||
* The results of node executions
|
* The results of node executions
|
||||||
*/
|
*/
|
||||||
results: Record<string, (IntCollectionOutput | FloatCollectionOutput | ModelLoaderOutput | LoraLoaderOutput | CompelOutput | ImageOutput | MaskOutput | ControlOutput | LatentsOutput | NoiseOutput | IntOutput | FloatOutput | PromptOutput | PromptCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | ModelLoaderOutput | LoraLoaderOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||||
/**
|
/**
|
||||||
* Errors raised when executing nodes
|
* Errors raised when executing nodes
|
||||||
*/
|
*/
|
||||||
@ -62,3 +62,4 @@ export type GraphExecutionState = {
|
|||||||
*/
|
*/
|
||||||
source_prepared_mapping: Record<string, Array<string>>;
|
source_prepared_mapping: Record<string, Array<string>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,76 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ImageField } from './ImageField';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generates an image using img2img.
|
|
||||||
*/
|
|
||||||
export type ImageToImageInvocation = {
|
|
||||||
/**
|
|
||||||
* The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Whether or not this node is an intermediate node.
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
type?: 'img2img';
|
|
||||||
/**
|
|
||||||
* The prompt to generate an image from
|
|
||||||
*/
|
|
||||||
prompt?: string;
|
|
||||||
/**
|
|
||||||
* The seed to use (omit for random)
|
|
||||||
*/
|
|
||||||
seed?: number;
|
|
||||||
/**
|
|
||||||
* The number of steps to use to generate the image
|
|
||||||
*/
|
|
||||||
steps?: number;
|
|
||||||
/**
|
|
||||||
* The width of the resulting image
|
|
||||||
*/
|
|
||||||
width?: number;
|
|
||||||
/**
|
|
||||||
* The height of the resulting image
|
|
||||||
*/
|
|
||||||
height?: number;
|
|
||||||
/**
|
|
||||||
* The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
|
||||||
*/
|
|
||||||
cfg_scale?: number;
|
|
||||||
/**
|
|
||||||
* The scheduler to use
|
|
||||||
*/
|
|
||||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
|
||||||
/**
|
|
||||||
* The model to use (currently ignored)
|
|
||||||
*/
|
|
||||||
model?: string;
|
|
||||||
/**
|
|
||||||
* Whether or not to produce progress images during generation
|
|
||||||
*/
|
|
||||||
progress_images?: boolean;
|
|
||||||
/**
|
|
||||||
* The control model to use
|
|
||||||
*/
|
|
||||||
control_model?: string;
|
|
||||||
/**
|
|
||||||
* The processed control image
|
|
||||||
*/
|
|
||||||
control_image?: ImageField;
|
|
||||||
/**
|
|
||||||
* The input image
|
|
||||||
*/
|
|
||||||
image?: ImageField;
|
|
||||||
/**
|
|
||||||
* The strength of the original image
|
|
||||||
*/
|
|
||||||
strength?: number;
|
|
||||||
/**
|
|
||||||
* Whether or not the result should be fit to the aspect ratio of the input image
|
|
||||||
*/
|
|
||||||
fit?: boolean;
|
|
||||||
};
|
|
@ -3,7 +3,10 @@
|
|||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
import type { ColorField } from './ColorField';
|
import type { ColorField } from './ColorField';
|
||||||
|
import type { ConditioningField } from './ConditioningField';
|
||||||
import type { ImageField } from './ImageField';
|
import type { ImageField } from './ImageField';
|
||||||
|
import type { UNetField } from './UNetField';
|
||||||
|
import type { VaeField } from './VaeField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates an image using inpaint.
|
* Generates an image using inpaint.
|
||||||
@ -19,9 +22,13 @@ export type InpaintInvocation = {
|
|||||||
is_intermediate?: boolean;
|
is_intermediate?: boolean;
|
||||||
type?: 'inpaint';
|
type?: 'inpaint';
|
||||||
/**
|
/**
|
||||||
* The prompt to generate an image from
|
* Positive conditioning for generation
|
||||||
*/
|
*/
|
||||||
prompt?: string;
|
positive_conditioning?: ConditioningField;
|
||||||
|
/**
|
||||||
|
* Negative conditioning for generation
|
||||||
|
*/
|
||||||
|
negative_conditioning?: ConditioningField;
|
||||||
/**
|
/**
|
||||||
* The seed to use (omit for random)
|
* The seed to use (omit for random)
|
||||||
*/
|
*/
|
||||||
@ -47,21 +54,13 @@ export type InpaintInvocation = {
|
|||||||
*/
|
*/
|
||||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
|
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'lms_k' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2s_k' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'dpmpp_2m_sde' | 'dpmpp_2m_sde_k' | 'dpmpp_sde' | 'dpmpp_sde_k' | 'unipc';
|
||||||
/**
|
/**
|
||||||
* The model to use (currently ignored)
|
* UNet model
|
||||||
*/
|
*/
|
||||||
model?: string;
|
unet?: UNetField;
|
||||||
/**
|
/**
|
||||||
* Whether or not to produce progress images during generation
|
* Vae model
|
||||||
*/
|
*/
|
||||||
progress_images?: boolean;
|
vae?: VaeField;
|
||||||
/**
|
|
||||||
* The control model to use
|
|
||||||
*/
|
|
||||||
control_model?: string;
|
|
||||||
/**
|
|
||||||
* The processed control image
|
|
||||||
*/
|
|
||||||
control_image?: ImageField;
|
|
||||||
/**
|
/**
|
||||||
* The input image
|
* The input image
|
||||||
*/
|
*/
|
||||||
|
@ -28,3 +28,4 @@ export type LoraInfo = {
|
|||||||
*/
|
*/
|
||||||
weight: number;
|
weight: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -35,3 +35,4 @@ export type LoraLoaderInvocation = {
|
|||||||
*/
|
*/
|
||||||
clip?: ClipField;
|
clip?: ClipField;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -19,3 +19,4 @@ export type LoraLoaderOutput = {
|
|||||||
*/
|
*/
|
||||||
clip?: ClipField;
|
clip?: ClipField;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -24,3 +24,4 @@ export type ModelInfo = {
|
|||||||
*/
|
*/
|
||||||
submodel?: SubModelType;
|
submodel?: SubModelType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -24,3 +24,4 @@ export type ModelLoaderOutput = {
|
|||||||
*/
|
*/
|
||||||
vae?: VaeField;
|
vae?: VaeField;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2,15 +2,16 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
import type { ControlNetModelConfig } from './ControlNetModelConfig';
|
import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
||||||
import type { LoRAModelConfig } from './LoRAModelConfig';
|
import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
||||||
import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig';
|
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
||||||
import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig';
|
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
||||||
import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig';
|
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
||||||
import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig';
|
import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
||||||
import type { TextualInversionModelConfig } from './TextualInversionModelConfig';
|
import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
||||||
import type { VaeModelConfig } from './VaeModelConfig';
|
import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config';
|
||||||
|
|
||||||
export type ModelsList = {
|
export type ModelsList = {
|
||||||
models: Record<string, Record<string, Record<string, (TextualInversionModelConfig | StableDiffusion2ModelDiffusersConfig | ControlNetModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion1ModelCheckpointConfig | VaeModelConfig | StableDiffusion1ModelDiffusersConfig | LoRAModelConfig)>>>;
|
models: Record<string, Record<string, Record<string, (invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig | invokeai__backend__model_management__models__controlnet__ControlNetModel__Config | invokeai__backend__model_management__models__lora__LoRAModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig | invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config | invokeai__backend__model_management__models__vae__VaeModel__Config | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig | invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig)>>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,3 +20,4 @@ export type SD1ModelLoaderInvocation = {
|
|||||||
*/
|
*/
|
||||||
model_name?: string;
|
model_name?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,3 +20,4 @@ export type SD2ModelLoaderInvocation = {
|
|||||||
*/
|
*/
|
||||||
model_name?: string;
|
model_name?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,64 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ImageField } from './ImageField';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generates an image using text2img.
|
|
||||||
*/
|
|
||||||
export type TextToImageInvocation = {
|
|
||||||
/**
|
|
||||||
* The id of this node. Must be unique among all nodes.
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
/**
|
|
||||||
* Whether or not this node is an intermediate node.
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
type?: 'txt2img';
|
|
||||||
/**
|
|
||||||
* The prompt to generate an image from
|
|
||||||
*/
|
|
||||||
prompt?: string;
|
|
||||||
/**
|
|
||||||
* The seed to use (omit for random)
|
|
||||||
*/
|
|
||||||
seed?: number;
|
|
||||||
/**
|
|
||||||
* The number of steps to use to generate the image
|
|
||||||
*/
|
|
||||||
steps?: number;
|
|
||||||
/**
|
|
||||||
* The width of the resulting image
|
|
||||||
*/
|
|
||||||
width?: number;
|
|
||||||
/**
|
|
||||||
* The height of the resulting image
|
|
||||||
*/
|
|
||||||
height?: number;
|
|
||||||
/**
|
|
||||||
* The Classifier-Free Guidance, higher values may result in a result closer to the prompt
|
|
||||||
*/
|
|
||||||
cfg_scale?: number;
|
|
||||||
/**
|
|
||||||
* The scheduler to use
|
|
||||||
*/
|
|
||||||
scheduler?: 'ddim' | 'ddpm' | 'deis' | 'lms' | 'pndm' | 'heun' | 'heun_k' | 'euler' | 'euler_k' | 'euler_a' | 'kdpm_2' | 'kdpm_2_a' | 'dpmpp_2s' | 'dpmpp_2m' | 'dpmpp_2m_k' | 'unipc';
|
|
||||||
/**
|
|
||||||
* The model to use (currently ignored)
|
|
||||||
*/
|
|
||||||
model?: string;
|
|
||||||
/**
|
|
||||||
* Whether or not to produce progress images during generation
|
|
||||||
*/
|
|
||||||
progress_images?: boolean;
|
|
||||||
/**
|
|
||||||
* The control model to use
|
|
||||||
*/
|
|
||||||
control_model?: string;
|
|
||||||
/**
|
|
||||||
* The processed control image
|
|
||||||
*/
|
|
||||||
control_image?: ImageField;
|
|
||||||
};
|
|
@ -19,3 +19,4 @@ export type UNetField = {
|
|||||||
*/
|
*/
|
||||||
loras: Array<LoraInfo>;
|
loras: Array<LoraInfo>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -10,3 +10,4 @@ export type VaeField = {
|
|||||||
*/
|
*/
|
||||||
vae: ModelInfo;
|
vae: ModelInfo;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: ('checkpoint' | 'diffusers');
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,14 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__lora__LoRAModel__Config = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: ('lycoris' | 'diffusers');
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,18 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: 'checkpoint';
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
vae?: string;
|
||||||
|
config?: string;
|
||||||
|
variant: ModelVariantType;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,17 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: 'diffusers';
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
vae?: string;
|
||||||
|
variant: ModelVariantType;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,21 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: 'checkpoint';
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
vae?: string;
|
||||||
|
config?: string;
|
||||||
|
variant: ModelVariantType;
|
||||||
|
prediction_type: SchedulerPredictionType;
|
||||||
|
upcast_attention: boolean;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,20 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
import type { ModelVariantType } from './ModelVariantType';
|
||||||
|
import type { SchedulerPredictionType } from './SchedulerPredictionType';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: 'diffusers';
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
vae?: string;
|
||||||
|
variant: ModelVariantType;
|
||||||
|
prediction_type: SchedulerPredictionType;
|
||||||
|
upcast_attention: boolean;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,14 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: null;
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,14 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type invokeai__backend__model_management__models__vae__VaeModel__Config = {
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
format: ('checkpoint' | 'diffusers');
|
||||||
|
default?: boolean;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -19,18 +19,18 @@ export class ModelsService {
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static listModels({
|
public static listModels({
|
||||||
baseModel,
|
baseModel,
|
||||||
modelType,
|
modelType,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* Base model
|
* Base model
|
||||||
*/
|
*/
|
||||||
baseModel?: BaseModelType,
|
baseModel?: BaseModelType,
|
||||||
/**
|
/**
|
||||||
* The type of model to get
|
* The type of model to get
|
||||||
*/
|
*/
|
||||||
modelType?: ModelType,
|
modelType?: ModelType,
|
||||||
}): CancelablePromise<ModelsList> {
|
}): CancelablePromise<ModelsList> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
url: '/api/v1/models/',
|
url: '/api/v1/models/',
|
||||||
@ -51,10 +51,10 @@ modelType?: ModelType,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static updateModel({
|
public static updateModel({
|
||||||
requestBody,
|
requestBody,
|
||||||
}: {
|
}: {
|
||||||
requestBody: CreateModelRequest,
|
requestBody: CreateModelRequest,
|
||||||
}): CancelablePromise<any> {
|
}): CancelablePromise<any> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
url: '/api/v1/models/',
|
url: '/api/v1/models/',
|
||||||
@ -73,10 +73,10 @@ requestBody: CreateModelRequest,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static delModel({
|
public static delModel({
|
||||||
modelName,
|
modelName,
|
||||||
}: {
|
}: {
|
||||||
modelName: string,
|
modelName: string,
|
||||||
}): CancelablePromise<any> {
|
}): CancelablePromise<any> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
url: '/api/v1/models/{model_name}',
|
url: '/api/v1/models/{model_name}',
|
||||||
|
@ -27,7 +27,6 @@ import type { ImagePasteInvocation } from '../models/ImagePasteInvocation';
|
|||||||
import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation';
|
import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation';
|
||||||
import type { ImageResizeInvocation } from '../models/ImageResizeInvocation';
|
import type { ImageResizeInvocation } from '../models/ImageResizeInvocation';
|
||||||
import type { ImageScaleInvocation } from '../models/ImageScaleInvocation';
|
import type { ImageScaleInvocation } from '../models/ImageScaleInvocation';
|
||||||
import type { ImageToImageInvocation } from '../models/ImageToImageInvocation';
|
|
||||||
import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation';
|
import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation';
|
||||||
import type { InfillColorInvocation } from '../models/InfillColorInvocation';
|
import type { InfillColorInvocation } from '../models/InfillColorInvocation';
|
||||||
import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation';
|
import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation';
|
||||||
@ -64,7 +63,6 @@ import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocatio
|
|||||||
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
|
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
|
||||||
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
|
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
|
||||||
import type { SubtractInvocation } from '../models/SubtractInvocation';
|
import type { SubtractInvocation } from '../models/SubtractInvocation';
|
||||||
import type { TextToImageInvocation } from '../models/TextToImageInvocation';
|
|
||||||
import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation';
|
import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation';
|
||||||
import type { UpscaleInvocation } from '../models/UpscaleInvocation';
|
import type { UpscaleInvocation } from '../models/UpscaleInvocation';
|
||||||
import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation';
|
import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation';
|
||||||
@ -82,23 +80,23 @@ export class SessionsService {
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static listSessions({
|
public static listSessions({
|
||||||
page,
|
page,
|
||||||
perPage = 10,
|
perPage = 10,
|
||||||
query = '',
|
query = '',
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The page of results to get
|
* The page of results to get
|
||||||
*/
|
*/
|
||||||
page?: number,
|
page?: number,
|
||||||
/**
|
/**
|
||||||
* The number of results per page
|
* The number of results per page
|
||||||
*/
|
*/
|
||||||
perPage?: number,
|
perPage?: number,
|
||||||
/**
|
/**
|
||||||
* The query string to search for
|
* The query string to search for
|
||||||
*/
|
*/
|
||||||
query?: string,
|
query?: string,
|
||||||
}): CancelablePromise<PaginatedResults_GraphExecutionState_> {
|
}): CancelablePromise<PaginatedResults_GraphExecutionState_> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
url: '/api/v1/sessions/',
|
url: '/api/v1/sessions/',
|
||||||
@ -120,10 +118,10 @@ query?: string,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static createSession({
|
public static createSession({
|
||||||
requestBody,
|
requestBody,
|
||||||
}: {
|
}: {
|
||||||
requestBody?: Graph,
|
requestBody?: Graph,
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
url: '/api/v1/sessions/',
|
url: '/api/v1/sessions/',
|
||||||
@ -143,13 +141,13 @@ requestBody?: Graph,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static getSession({
|
public static getSession({
|
||||||
sessionId,
|
sessionId,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session to get
|
* The id of the session to get
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
url: '/api/v1/sessions/{session_id}',
|
url: '/api/v1/sessions/{session_id}',
|
||||||
@ -170,15 +168,15 @@ sessionId: string,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static addNode({
|
public static addNode({
|
||||||
sessionId,
|
sessionId,
|
||||||
requestBody,
|
requestBody,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | DynamicPromptInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation),
|
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||||
}): CancelablePromise<string> {
|
}): CancelablePromise<string> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
url: '/api/v1/sessions/{session_id}/nodes',
|
url: '/api/v1/sessions/{session_id}/nodes',
|
||||||
@ -202,20 +200,20 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static updateNode({
|
public static updateNode({
|
||||||
sessionId,
|
sessionId,
|
||||||
nodePath,
|
nodePath,
|
||||||
requestBody,
|
requestBody,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
/**
|
/**
|
||||||
* The path to the node in the graph
|
* The path to the node in the graph
|
||||||
*/
|
*/
|
||||||
nodePath: string,
|
nodePath: string,
|
||||||
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | DynamicPromptInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation),
|
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
|
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
|
||||||
@ -240,18 +238,18 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static deleteNode({
|
public static deleteNode({
|
||||||
sessionId,
|
sessionId,
|
||||||
nodePath,
|
nodePath,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
/**
|
/**
|
||||||
* The path to the node to delete
|
* The path to the node to delete
|
||||||
*/
|
*/
|
||||||
nodePath: string,
|
nodePath: string,
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
|
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
|
||||||
@ -274,15 +272,15 @@ nodePath: string,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static addEdge({
|
public static addEdge({
|
||||||
sessionId,
|
sessionId,
|
||||||
requestBody,
|
requestBody,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
requestBody: Edge,
|
requestBody: Edge,
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
url: '/api/v1/sessions/{session_id}/edges',
|
url: '/api/v1/sessions/{session_id}/edges',
|
||||||
@ -306,33 +304,33 @@ requestBody: Edge,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static deleteEdge({
|
public static deleteEdge({
|
||||||
sessionId,
|
sessionId,
|
||||||
fromNodeId,
|
fromNodeId,
|
||||||
fromField,
|
fromField,
|
||||||
toNodeId,
|
toNodeId,
|
||||||
toField,
|
toField,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session
|
* The id of the session
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
/**
|
/**
|
||||||
* The id of the node the edge is coming from
|
* The id of the node the edge is coming from
|
||||||
*/
|
*/
|
||||||
fromNodeId: string,
|
fromNodeId: string,
|
||||||
/**
|
/**
|
||||||
* The field of the node the edge is coming from
|
* The field of the node the edge is coming from
|
||||||
*/
|
*/
|
||||||
fromField: string,
|
fromField: string,
|
||||||
/**
|
/**
|
||||||
* The id of the node the edge is going to
|
* The id of the node the edge is going to
|
||||||
*/
|
*/
|
||||||
toNodeId: string,
|
toNodeId: string,
|
||||||
/**
|
/**
|
||||||
* The field of the node the edge is going to
|
* The field of the node the edge is going to
|
||||||
*/
|
*/
|
||||||
toField: string,
|
toField: string,
|
||||||
}): CancelablePromise<GraphExecutionState> {
|
}): CancelablePromise<GraphExecutionState> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
url: '/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
|
url: '/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
|
||||||
@ -358,18 +356,18 @@ toField: string,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static invokeSession({
|
public static invokeSession({
|
||||||
sessionId,
|
sessionId,
|
||||||
all = false,
|
all = false,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session to invoke
|
* The id of the session to invoke
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
/**
|
/**
|
||||||
* Whether or not to invoke all remaining invocations
|
* Whether or not to invoke all remaining invocations
|
||||||
*/
|
*/
|
||||||
all?: boolean,
|
all?: boolean,
|
||||||
}): CancelablePromise<any> {
|
}): CancelablePromise<any> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
url: '/api/v1/sessions/{session_id}/invoke',
|
url: '/api/v1/sessions/{session_id}/invoke',
|
||||||
@ -394,13 +392,13 @@ all?: boolean,
|
|||||||
* @throws ApiError
|
* @throws ApiError
|
||||||
*/
|
*/
|
||||||
public static cancelSessionInvoke({
|
public static cancelSessionInvoke({
|
||||||
sessionId,
|
sessionId,
|
||||||
}: {
|
}: {
|
||||||
/**
|
/**
|
||||||
* The id of the session to cancel
|
* The id of the session to cancel
|
||||||
*/
|
*/
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
}): CancelablePromise<any> {
|
}): CancelablePromise<any> {
|
||||||
return __request(OpenAPI, {
|
return __request(OpenAPI, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
url: '/api/v1/sessions/{session_id}/invoke',
|
url: '/api/v1/sessions/{session_id}/invoke',
|
||||||
|
Loading…
Reference in New Issue
Block a user